diff --git a/.gitignore b/.gitignore index fab9f00a5..4ec4f4bd6 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.pyd *.log *.csv +*.parquet *.c *.cpp *.DS_Store @@ -12,15 +13,18 @@ .vs/ build/ log/ +logs/ +checkpoint/ +checkpoints/ +streamit/ dist/ *.egg-info/ tools/schedule docs/_build -test/ -data/ .eggs/ maro_venv/ pyvenv.cfg htmlcov/ -.coverage -.coveragerc +.coverage +.coveragerc +.tmp/ diff --git a/docker_files/dev.df b/docker_files/dev.df new file mode 100644 index 000000000..00677003e --- /dev/null +++ b/docker_files/dev.df @@ -0,0 +1,36 @@ +FROM python:3.7-buster +WORKDIR /maro + +# Install Apt packages +RUN apt-get update --fix-missing +RUN apt-get install -y apt-utils +RUN apt-get install -y sudo +RUN apt-get install -y gcc +RUN apt-get install -y libcurl4 libcurl4-openssl-dev libssl-dev curl +RUN apt-get install -y libzmq3-dev +RUN apt-get install -y python3-pip +RUN apt-get install -y python3-dev libpython3.7-dev python-numpy +RUN rm -rf /var/lib/apt/lists/* + +# Install Python packages +RUN pip install --upgrade pip +RUN pip install --no-cache-dir Cython==0.29.14 +RUN pip install --no-cache-dir pyaml==20.4.0 +RUN pip install --no-cache-dir pyzmq==19.0.2 +RUN pip install --no-cache-dir numpy==1.19.1 +RUN pip install --no-cache-dir matplotlib +RUN pip install --no-cache-dir torch==1.6.0 +RUN pip install --no-cache-dir scipy +RUN pip install --no-cache-dir matplotlib +RUN pip install --no-cache-dir redis +RUN pip install --no-cache-dir networkx + +COPY maro /maro/maro +COPY scripts /maro/scripts/ +COPY setup.py /maro/ +RUN bash /maro/scripts/install_maro.sh +RUN pip cache purge + +ENV PYTHONPATH=/maro + +CMD ["/bin/bash"] diff --git a/docs/source/apidoc/maro.rl.rst b/docs/source/apidoc/maro.rl.rst index d521dd652..aa3bc47df 100644 --- a/docs/source/apidoc/maro.rl.rst +++ b/docs/source/apidoc/maro.rl.rst @@ -1,198 +1,330 @@ -Agent +Distributed ================================================================================ -maro.rl.agent.abs\_agent +maro.rl.distributed.abs_proxy -------------------------------------------------------------------------------- -.. automodule:: maro.rl.agent.abs_agent +.. automodule:: maro.rl.distributed.abs_proxy :members: :undoc-members: :show-inheritance: -maro.rl.agent.dqn +maro.rl.distributed.abs_worker -------------------------------------------------------------------------------- -.. automodule:: maro.rl.agent.dqn +.. automodule:: maro.rl.distributed.abs_worker :members: :undoc-members: :show-inheritance: -maro.rl.agent.ddpg +Exploration +================================================================================ + +maro.rl.exploration.scheduling -------------------------------------------------------------------------------- -.. automodule:: maro.rl.agent.ddpg +.. automodule:: maro.rl.exploration.scheduling :members: :undoc-members: :show-inheritance: -maro.rl.agent.policy\_optimization +maro.rl.exploration.strategies -------------------------------------------------------------------------------- -.. automodule:: maro.rl.agent.policy_optimization +.. automodule:: maro.rl.exploration.strategies :members: :undoc-members: :show-inheritance: - -Agent Manager +Model ================================================================================ -maro.rl.agent.abs\_agent\_manager +maro.rl.model.algorithm_nets -------------------------------------------------------------------------------- -.. automodule:: maro.rl.agent.abs_agent_manager +.. 
automodule:: maro.rl.model.algorithm_nets :members: :undoc-members: :show-inheritance: +maro.rl.model.abs_net +-------------------------------------------------------------------------------- -Model -================================================================================ +.. automodule:: maro.rl.model.abs_net + :members: + :undoc-members: + :show-inheritance: + +maro.rl.model.fc_block +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.model.fc_block + :members: + :undoc-members: + :show-inheritance: + +maro.rl.model.multi_q_net +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.model.multi_q_net + :members: + :undoc-members: + :show-inheritance: + +maro.rl.model.policy_net +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.model.policy_net + :members: + :undoc-members: + :show-inheritance: -maro.rl.model.learning\_model +maro.rl.model.q_net -------------------------------------------------------------------------------- -.. automodule:: maro.rl.model.torch.learning_model +.. automodule:: maro.rl.model.q_net :members: :undoc-members: :show-inheritance: +maro.rl.model.v_net +-------------------------------------------------------------------------------- -Explorer +.. automodule:: maro.rl.model.v_net + :members: + :undoc-members: + :show-inheritance: + +Policy ================================================================================ -maro.rl.exploration.abs\_explorer +maro.rl.policy.abs_policy -------------------------------------------------------------------------------- -.. automodule:: maro.rl.exploration.abs_explorer +.. automodule:: maro.rl.policy.abs_policy :members: :undoc-members: :show-inheritance: -maro.rl.exploration.epsilon\_greedy\_explorer +maro.rl.policy.continuous_rl_policy -------------------------------------------------------------------------------- -.. automodule:: maro.rl.exploration.epsilon_greedy_explorer +.. automodule:: maro.rl.policy.continuous_rl_policy :members: :undoc-members: :show-inheritance: -maro.rl.exploration.noise\_explorer +maro.rl.policy.discrete_rl_policy -------------------------------------------------------------------------------- -.. automodule:: maro.rl.exploration.noise_explorer +.. automodule:: maro.rl.policy.discrete_rl_policy :members: :undoc-members: :show-inheritance: +RL Component +================================================================================ + +maro.rl.rl_component.rl_component_bundle +-------------------------------------------------------------------------------- -Scheduler +.. automodule:: maro.rl.rl_component.rl_component_bundle + :members: + :undoc-members: + :show-inheritance: + +Rollout ================================================================================ -maro.rl.scheduling.scheduler +maro.rl.rollout.batch_env_sampler -------------------------------------------------------------------------------- -.. automodule:: maro.rl.scheduling.scheduler +.. automodule:: maro.rl.rollout.batch_env_sampler :members: :undoc-members: :show-inheritance: -maro.rl.scheduling.simple\_parameter\_scheduler +maro.rl.rollout.env_sampler -------------------------------------------------------------------------------- -.. automodule:: maro.rl.scheduling.simple_parameter_scheduler +.. 
automodule:: maro.rl.rollout.env_sampler :members: :undoc-members: :show-inheritance: +maro.rl.rollout.worker +-------------------------------------------------------------------------------- -Shaping +.. automodule:: maro.rl.rollout.worker + :members: + :undoc-members: + :show-inheritance: + +Training ================================================================================ -maro.rl.shaping.abs\_shaper +maro.rl.training.algorithms -------------------------------------------------------------------------------- -.. automodule:: maro.rl.shaping.abs_shaper +.. automodule:: maro.rl.training.algorithms :members: :undoc-members: :show-inheritance: +maro.rl.training.proxy +-------------------------------------------------------------------------------- -Storage -================================================================================ +.. automodule:: maro.rl.training.proxy + :members: + :undoc-members: + :show-inheritance: -maro.rl.storage.abs\_store +maro.rl.training.replay_memory -------------------------------------------------------------------------------- -.. automodule:: maro.rl.storage.abs_store +.. automodule:: maro.rl.training.replay_memory :members: :undoc-members: :show-inheritance: -maro.rl.storage.simple\_store +maro.rl.training.trainer -------------------------------------------------------------------------------- -.. automodule:: maro.rl.storage.simple_store +.. automodule:: maro.rl.training.trainer :members: :undoc-members: :show-inheritance: +maro.rl.training.training_manager +-------------------------------------------------------------------------------- -Actor -================================================================================ +.. automodule:: maro.rl.training.training_manager + :members: + :undoc-members: + :show-inheritance: -maro.rl.actor.abs\_actor +maro.rl.training.train_ops -------------------------------------------------------------------------------- -.. automodule:: maro.rl.actor.abs_actor +.. automodule:: maro.rl.training.train_ops :members: :undoc-members: :show-inheritance: -maro.rl.actor.simple\_actor +maro.rl.training.utils -------------------------------------------------------------------------------- -.. automodule:: maro.rl.actor.simple_actor +.. automodule:: maro.rl.training.utils :members: :undoc-members: :show-inheritance: +maro.rl.training.worker +-------------------------------------------------------------------------------- -Learner +.. automodule:: maro.rl.training.worker + :members: + :undoc-members: + :show-inheritance: + +Utils ================================================================================ -maro.rl.learner.abs\_learner +maro.rl.utils.common +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.utils.common + :members: + :undoc-members: + :show-inheritance: + +maro.rl.utils.message_enums -------------------------------------------------------------------------------- -.. automodule:: maro.rl.learner.abs_learner +.. automodule:: maro.rl.utils.message_enums :members: :undoc-members: :show-inheritance: -maro.rl.learner.simple\_learner +maro.rl.utils.objects -------------------------------------------------------------------------------- -.. automodule:: maro.rl.learner.simple_learner +.. automodule:: maro.rl.utils.objects :members: :undoc-members: :show-inheritance: +maro.rl.utils.torch_utils +-------------------------------------------------------------------------------- -Distributed Topologies +.. 
automodule:: maro.rl.utils.torch_utils + :members: + :undoc-members: + :show-inheritance: + +maro.rl.utils.trajectory_computation +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.utils.trajectory_computation + :members: + :undoc-members: + :show-inheritance: + +maro.rl.utils.transition_batch +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.utils.transition_batch + :members: + :undoc-members: + :show-inheritance: + +Workflows ================================================================================ -maro.rl.dist\_topologies.common +maro.rl.workflows.config +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.workflows.config + :members: + :undoc-members: + :show-inheritance: + +maro.rl.workflows.main +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.workflows.main + :members: + :undoc-members: + :show-inheritance: + +maro.rl.workflows.rollout_worker +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.workflows.rollout_worker + :members: + :undoc-members: + :show-inheritance: + +maro.rl.workflows.scenario +-------------------------------------------------------------------------------- + +.. automodule:: maro.rl.workflows.scenario + :members: + :undoc-members: + :show-inheritance: + +maro.rl.workflows.train_proxy -------------------------------------------------------------------------------- -.. automodule:: maro.rl.dist_topologies.common +.. automodule:: maro.rl.workflows.train_proxy :members: :undoc-members: :show-inheritance: -maro.rl.dist\_topologies.single\_learner\_multi\_actor\_sync\_mode +maro.rl.workflows.train_worker -------------------------------------------------------------------------------- -.. automodule:: maro.rl.dist_topologies.single_learner_multi_actor_sync_mode +.. automodule:: maro.rl.workflows.train_worker :members: :undoc-members: :show-inheritance: diff --git a/docs/source/conf.py b/docs/source/conf.py index 0409844ec..e423c0781 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -100,3 +100,5 @@ } source_suffix = [".md", ".rst"] + +numfig = True diff --git a/docs/source/examples/citi_bike.rst b/docs/source/examples/citi_bike.rst deleted file mode 100644 index b9aae2e3d..000000000 --- a/docs/source/examples/citi_bike.rst +++ /dev/null @@ -1,75 +0,0 @@ -Example Scenario: Bike Repositioning (Citi Bike) -================================================ - -In this example we demonstrate using a simple greedy policy for `Citi Bike `_, -a real-world bike repositioning scenario. - -Greedy Policy -------------- - -Our greedy policy is simple: if the event type is supply, the policy will make -the current station send as many bikes as possible to one of k stations with the most empty docks. If the event type is -demand, the policy will make the current station request as many bikes as possible from one of k stations with the most -bikes. We use a heap data structure to find the top k supply/demand candidates from the action scope associated with -each decision event. - -.. code-block:: python - - class GreedyPolicy: - ... 
- def choose_action(self, decision_event: DecisionEvent): - if decision_event.type == DecisionType.Supply: - """ - Find k target stations with the most empty slots, randomly choose one of them and send as many bikes to - it as allowed by the action scope - """ - top_k_demands = [] - for demand_candidate, available_docks in decision_event.action_scope.items(): - if demand_candidate == decision_event.station_idx: - continue - - heapq.heappush(top_k_demands, (available_docks, demand_candidate)) - if len(top_k_demands) > self._demand_top_k: - heapq.heappop(top_k_demands) - - max_reposition, target_station_idx = random.choice(top_k_demands) - action = Action(decision_event.station_idx, target_station_idx, max_reposition) - else: - """ - Find k source stations with the most bikes, randomly choose one of them and request as many bikes from - it as allowed by the action scope. - """ - top_k_supplies = [] - for supply_candidate, available_bikes in decision_event.action_scope.items(): - if supply_candidate == decision_event.station_idx: - continue - - heapq.heappush(top_k_supplies, (available_bikes, supply_candidate)) - if len(top_k_supplies) > self._supply_top_k: - heapq.heappop(top_k_supplies) - - max_reposition, source_idx = random.choice(top_k_supplies) - action = Action(source_idx, decision_event.station_idx, max_reposition) - - return action - - -Interaction with the Greedy Policy ----------------------------------- - -This environment is driven by `real trip history data `_ from Citi Bike. - -.. code-block:: python - - env = Env(scenario=config.env.scenario, topology=config.env.topology, start_tick=config.env.start_tick, - durations=config.env.durations, snapshot_resolution=config.env.resolution) - - if config.env.seed is not None: - env.set_seed(config.env.seed) - - policy = GreedyPolicy(config.agent.supply_top_k, config.agent.demand_top_k) - metrics, decision_event, done = env.step(None) - while not done: - metrics, decision_event, done = env.step(policy.choose_action(decision_event)) - - env.reset() \ No newline at end of file diff --git a/docs/source/examples/multi_agent_dqn_cim.rst b/docs/source/examples/multi_agent_dqn_cim.rst deleted file mode 100644 index f85b3e2a9..000000000 --- a/docs/source/examples/multi_agent_dqn_cim.rst +++ /dev/null @@ -1,168 +0,0 @@ -Multi Agent DQN for CIM -================================================ - -This example demonstrates how to use MARO's reinforcement learning (RL) toolkit to solve the container -inventory management (CIM) problem. It is formalized as a multi-agent reinforcement learning problem, -where each port acts as a decision agent. When a vessel arrives at a port, these agents must take actions -by transferring a certain amount of containers to / from the vessel. The objective is for the agents to -learn policies that minimize the overall container shortage. - -Trajectory ----------- - -The ``CIMTrajectoryForDQN`` inherits from ``Trajectory`` function and implements methods to be used as callbacks -in the roll-out loop. In this example, - * ``get_state`` converts environment observations to state vectors that encode temporal and spatial information. - The temporal information includes relevant port and vessel information, such as shortage and remaining space, - over the past k days (here k = 7). The spatial information includes features of the downstream ports. 
- * ``get_action`` converts agents' output (an integer that maps to a percentage of containers to be loaded - to or unloaded from the vessel) to action objects that can be executed by the environment. - * ``get_offline_reward`` computes the reward of a given action as a linear combination of fulfillment and - shortage within a future time frame. - * ``on_finish`` processes a complete trajectory into data that can be used directly by the learning agents. - - -.. code-block:: python - class CIMTrajectoryForDQN(Trajectory): - def __init__( - self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream, - reward_time_window, fulfillment_factor, shortage_factor, time_decay, - finite_vessel_space=True, has_early_discharge=True - ): - super().__init__(env) - self.port_attributes = port_attributes - self.vessel_attributes = vessel_attributes - self.action_space = action_space - self.look_back = look_back - self.max_ports_downstream = max_ports_downstream - self.reward_time_window = reward_time_window - self.fulfillment_factor = fulfillment_factor - self.shortage_factor = shortage_factor - self.time_decay = time_decay - self.finite_vessel_space = finite_vessel_space - self.has_early_discharge = has_early_discharge - - def get_state(self, event): - vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"] - tick, port_idx, vessel_idx = event.tick, event.port_idx, event.vessel_idx - ticks = [max(0, tick - rt) for rt in range(self.look_back - 1)] - future_port_idx_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int') - port_features = port_snapshots[ticks: [port_idx] + list(future_port_idx_list): self.port_attributes] - vessel_features = vessel_snapshots[tick: vessel_idx: self.vessel_attributes] - return {port_idx: np.concatenate((port_features, vessel_features))} - - def get_action(self, action_by_agent, event): - vessel_snapshots = self.env.snapshot_list["vessels"] - action_info = list(action_by_agent.values())[0] - model_action = action_info[0] if isinstance(action_info, tuple) else action_info - scope, tick, port, vessel = event.action_scope, event.tick, event.port_idx, event.vessel_idx - zero_action_idx = len(self.action_space) / 2 # index corresponding to value zero. 
- vessel_space = vessel_snapshots[tick:vessel:self.vessel_attributes][2] if self.finite_vessel_space else float("inf") - early_discharge = vessel_snapshots[tick:vessel:"early_discharge"][0] if self.has_early_discharge else 0 - percent = abs(self.action_space[model_action]) - - if model_action < zero_action_idx: - action_type = ActionType.LOAD - actual_action = min(round(percent * scope.load), vessel_space) - elif model_action > zero_action_idx: - action_type = ActionType.DISCHARGE - plan_action = percent * (scope.discharge + early_discharge) - early_discharge - actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge) - else: - actual_action, action_type = 0, ActionType.LOAD - - return {port: Action(vessel, port, actual_action, action_type)} - - def get_offline_reward(self, event): - port_snapshots = self.env.snapshot_list["ports"] - start_tick = event.tick + 1 - ticks = list(range(start_tick, start_tick + self.reward_time_window)) - - future_fulfillment = port_snapshots[ticks::"fulfillment"] - future_shortage = port_snapshots[ticks::"shortage"] - decay_list = [ - self.time_decay ** i for i in range(self.reward_time_window) - for _ in range(future_fulfillment.shape[0] // self.reward_time_window) - ] - - tot_fulfillment = np.dot(future_fulfillment, decay_list) - tot_shortage = np.dot(future_shortage, decay_list) - - return np.float32(self.fulfillment_factor * tot_fulfillment - self.shortage_factor * tot_shortage) - - def on_env_feedback(self, event, state_by_agent, action_by_agent, reward): - self.trajectory["event"].append(event) - self.trajectory["state"].append(state_by_agent) - self.trajectory["action"].append(action_by_agent) - - def on_finish(self): - exp_by_agent = defaultdict(lambda: defaultdict(list)) - for i in range(len(self.trajectory["state"]) - 1): - agent_id = list(self.trajectory["state"][i].keys())[0] - exp = exp_by_agent[agent_id] - exp["S"].append(self.trajectory["state"][i][agent_id]) - exp["A"].append(self.trajectory["action"][i][agent_id]) - exp["R"].append(self.get_offline_reward(self.trajectory["event"][i])) - exp["S_"].append(list(self.trajectory["state"][i + 1].values())[0]) - - return dict(exp_by_agent) - - -Agent ------ - -The out-of-the-box DQN is used as our agent. - -.. code-block:: python - agent_config = { - "model": ..., - "optimization": ..., - "hyper_params": ... - } - - def get_dqn_agent(): - q_model = SimpleMultiHeadModel( - FullyConnectedBlock(**agent_config["model"]), optim_option=agent_config["optimization"] - ) - return DQN(q_model, DQNConfig(**agent_config["hyper_params"])) - - -Training --------- - -The distributed training consists of one learner process and multiple actor processes. The learner optimizes -the policy by collecting roll-out data from the actors to train the underlying agents. - -The actor process must create a roll-out executor for performing the requested roll-outs, which means that the -the environment simulator and shapers should be created here. In this example, inference is performed on the -actor's side, so a set of DQN agents must be created in order to load the models (and exploration parameters) -from the learner. - -.. 
code-block:: python - def cim_dqn_actor(): - env = Env(**training_config["env"]) - agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list}) - actor = Actor(env, agent, CIMTrajectoryForDQN, trajectory_kwargs=common_config) - actor.as_worker(training_config["group"]) - -The learner's side requires a concrete learner class that inherits from ``AbsLearner`` and implements the ``run`` -method which contains the main training loop. Here the implementation is similar to the single-threaded version -except that the ``collect`` method is used to obtain roll-out data from the actors (since the roll-out executors -are located on the actors' side). The agents created here are where training occurs and hence always contains the -latest policies. - -.. code-block:: python - def cim_dqn_learner(): - env = Env(**training_config["env"]) - agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list}) - scheduler = TwoPhaseLinearParameterScheduler(training_config["max_episode"], **training_config["exploration"]) - actor = ActorProxy( - training_config["group"], training_config["num_actors"], - update_trigger=training_config["learner_update_trigger"] - ) - learner = OffPolicyLearner(actor, scheduler, agent, **training_config["training"]) - learner.run() - -.. note:: - - All related code snippets are supported in `maro playground `_. diff --git a/docs/source/images/rl/agent.svg b/docs/source/images/rl/agent.svg deleted file mode 100644 index 8359bb8bb..000000000 --- a/docs/source/images/rl/agent.svg +++ /dev/null @@ -1,3 +0,0 @@ - - -
[SVG markup omitted; diagram text: Agent, Model, Loss Function, Experience Pool (optional), Component, train, choose_action, samples, store]
\ No newline at end of file diff --git a/docs/source/images/rl/distributed_training.svg b/docs/source/images/rl/distributed_training.svg new file mode 100644 index 000000000..546e80a5f --- /dev/null +++ b/docs/source/images/rl/distributed_training.svg @@ -0,0 +1,4 @@ + + + +
[SVG markup omitted; diagram text: Trainer Manager, Proxy, Worker, Task, Result]
\ No newline at end of file diff --git a/docs/source/images/rl/env_sampler.svg b/docs/source/images/rl/env_sampler.svg new file mode 100644 index 000000000..04c09f31a --- /dev/null +++ b/docs/source/images/rl/env_sampler.svg @@ -0,0 +1,4 @@ + + + +
[SVG markup omitted; diagram text: Environment Simulator, Environment Sampler, Agent, Policy, actions, states (global and per agent), rewards]
\ No newline at end of file diff --git a/docs/source/images/rl/learner_actor.svg b/docs/source/images/rl/learner_actor.svg deleted file mode 100644 index ae0495e19..000000000 --- a/docs/source/images/rl/learner_actor.svg +++ /dev/null @@ -1,3 +0,0 @@ - - -
[SVG markup omitted; diagram text: Learner, Actor, Agent(s), Roll-out Executor, roll-out request, roll-out result, decision request, roll out, train]
\ No newline at end of file diff --git a/docs/source/images/rl/learning_workflow.svg b/docs/source/images/rl/learning_workflow.svg new file mode 100644 index 000000000..ead7a9b14 --- /dev/null +++ b/docs/source/images/rl/learning_workflow.svg @@ -0,0 +1,4 @@ + + + +
[SVG markup omitted; diagram text: Roll-out Phase, Training Phase, training data, updated policies]
\ No newline at end of file diff --git a/docs/source/images/rl/parallel_rollout.svg b/docs/source/images/rl/parallel_rollout.svg new file mode 100644 index 000000000..7b18e1928 --- /dev/null +++ b/docs/source/images/rl/parallel_rollout.svg @@ -0,0 +1,4 @@ + + + +
[SVG markup omitted; diagram text: Roll-out Controller, Worker, environment simulator, roll-out message (sample / eval), results]
\ No newline at end of file diff --git a/docs/source/images/rl/policy_manager.svg b/docs/source/images/rl/policy_manager.svg new file mode 100644 index 000000000..96cc1b31c --- /dev/null +++ b/docs/source/images/rl/policy_manager.svg @@ -0,0 +1,3 @@ + + +
[SVG markup omitted; diagram text: Simple Policy Manager, Distributed Policy Manager, Policy, Policy Host, Trainer (Grad Worker), Task Queue, Request Workers, Worker IDs for training, Update Available Worker List, auto-balanced task dispatching, policy states, experience batches, gradients, update, get_state]
diff --git a/docs/source/images/rl/policy_model_trainer.svg b/docs/source/images/rl/policy_model_trainer.svg new file mode 100644 index 000000000..74697e159 --- /dev/null +++ b/docs/source/images/rl/policy_model_trainer.svg @@ -0,0 +1,4 @@ + + + +
[SVG markup omitted; diagram text: Policy, Model, Trainer, Experience Memory, Auxiliary models, Training-related interfaces, Environment (states, actions, record experiences), Monitor]
\ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 2b3fe29b4..2564203d1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -89,7 +89,6 @@ Contents :maxdepth: 2 :caption: Examples - examples/multi_agent_dqn_cim.rst examples/greedy_policy_citi_bike.rst .. toctree:: diff --git a/docs/source/key_components/communication.rst b/docs/source/key_components/communication.rst index fa6926af1..bd501a051 100644 --- a/docs/source/key_components/communication.rst +++ b/docs/source/key_components/communication.rst @@ -43,7 +43,7 @@ The main attributes of a message instance include: message = Message(tag="check_in", source="worker_001", destination="master", - payload="") + body="") Session Message ^^^^^^^^^^^^^^^ @@ -71,13 +71,13 @@ The stages of each session are maintained internally by the proxy. task_message = SessionMessage(tag="sum", source="master", destination="worker_001", - payload=[0, 1, 2, ...], + body=[0, 1, 2, ...], session_type=SessionType.TASK) notification_message = SessionMessage(tag="check_out", source="worker_001", destination="master", - payload="", + body="", session_type=SessionType.NOTIFICATION) Communication Primitives diff --git a/docs/source/key_components/data_model.rst b/docs/source/key_components/data_model.rst index 00a1dd95b..28b9e910c 100644 --- a/docs/source/key_components/data_model.rst +++ b/docs/source/key_components/data_model.rst @@ -259,3 +259,799 @@ For better data access, we also provide some advanced features, including: # Also with dynamic implementation, we can get the const attributes which is shared between snapshot list, even without # any snapshot (need to provided one tick for padding). states = test_nodes_snapshots[0: [0, 1]: ["const_attribute", "const_attribute_2"]] + + + +States in built-in scenarios' snapshot list +------------------------------------------- + +.. TODO: move to environment part? + +Currently there are 3 ways to expose states in built-in scenarios: + +Summary +~~~~~~~~~~~ + +Summary(env.summary) is used to expose static states to outside, it provide 3 items by default: +node_mapping, node_detail and event payload. + +The "node_mapping" item usually contains node name and related index, but the structure may be different +for different scenario. + +The "node_detail" usually used to expose node definitions, like node name, attribute name and slot number, +this is useful if you want to know what attributes are support for a scenario. + +The "event_payload" used show that payload attributes of event in scenario, like "RETURN_FULL" event in +CIM scenario, it contains "src_port_idx", "dest_port_idx" and "quantity". + +Metrics +~~~~~~~ + +Metrics(env.metrics) is designed that used to expose raw states of reward since we have removed reward +support in v0.2 version, and it also can be used to export states that not supported by snapshot list, like dictionary or complex +structures. Currently there are 2 ways to get the metrics from environment: env.metrics, or 1st result from env.step. + +This metrics usually is a dictionary with several keys, but this is determined by business engine. + +Snapshot_list +~~~~~~~~~~~~~ + +Snapshot list is the history of nodes (or data model) for a scenario, it only support numberic data types now. +It supported slicing query with a numpy array, so it support batch operations, make it much faster than +using raw python objects. + +Nodes and attributes may different for different scenarios, following we will introduce about those in +built-in scenarios. 
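As a quick end-to-end illustration, the three access points can be inspected together. The snippet below is a minimal sketch: the CIM topology and durations are placeholders, and any built-in scenario works the same way.

.. code-block:: python

    from maro.simulator import Env

    env = Env(scenario="cim", topology="toy.4p_ssdd_l0.0", start_tick=0, durations=100)

    # Summary: static states such as node mapping, node definitions and event payload definitions.
    print(env.summary["node_mapping"])
    print(env.summary["node_detail"])
    print(env.summary["event_payload"])

    # Step once so that metrics and snapshots are populated.
    metrics, decision_event, done = env.step(None)

    # Metrics: a plain dictionary defined by the business engine (also available as env.metrics).
    print(metrics)

    # Snapshot list: batch slicing query that returns a flat numpy array.
    shortage_at_tick_0 = env.snapshot_list["ports"][0::"shortage"]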
+ +NOTE: +Per tick state means that the attribute value will be reset to 0 after each step. + +CIM +--- + +Default settings for snapshot list +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Snapshot resolution: 1 + + +Max snapshot number: same as durations + +Nodes and attributes in scenario +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In CIM scenario, there are 3 node types: + + +port +++++ + +capacity +******** + +type: int +slots: 1 + +The capacity of port for stocking containers. + +empty +***** + +type: int +slots: 1 + +Empty container volume on the port. + +full +**** + +type: int +slots: 1 + +Laden container volume on the port. + +on_shipper +********** + +type: int +slots: 1 + +Empty containers, which are released to the shipper. + +on_consignee +************ + +type: int +slots: 1 + +Laden containers, which are delivered to the consignee. + +shortage +******** + +type: int +slots: 1 + +Per tick state. Shortage of empty container at current tick. + +acc_storage +*********** + +type: int +slots: 1 + +Accumulated shortage number to the current tick. + +booking +******* + +type: int +slots: 1 + +Per tick state. Order booking number of a port at the current tick. + +acc_booking +*********** + +type: int +slots: 1 + +Accumulated order booking number of a port to the current tick. + +fulfillment +*********** + +type: int +slots: 1 + +Fulfilled order number of a port at the current tick. + +acc_fulfillment +*************** + +type: int +slots: 1 + +Accumulated fulfilled order number of a port to the current tick. + +transfer_cost +************* + +type: float +slots: 1 + +Cost of transferring container, which also covers loading and discharging cost. + +vessel +++++++ + +capacity +******** + +type: int +slots: 1 + +The capacity of vessel for transferring containers. + +NOTE: +This attribute is ignored in current implementation. + +empty +***** + +type: int +slots: 1 + +Empty container volume on the vessel. + +full +**** + +type: int +slots: 1 + +Laden container volume on the vessel. + +remaining_space +*************** + +type: int +slots: 1 + +Remaining space of the vessel. + +early_discharge +*************** + +type: int +slots: 1 + +Discharged empty container number for loading laden containers. + +route_idx +********* + +type: int +slots: 1 + +Which route current vessel belongs to. + +last_loc_idx +************ + +type: int +slots: 1 + +Last stop port index in route, it is used to identify where is current vessel. + +next_loc_idx +************ + +type: int +slots: 1 + +Next stop port index in route, it is used to identify where is current vessel. + +past_stop_list +************** + +type: int +slots: dynamic + +NOTE: +This and following attribute are special, that its slot number is determined by configuration, +but different with a list attribute, its slot number is fixed at runtime. + +Stop indices that we have stopped in the past. + +past_stop_tick_list +******************* + +type: int +slots: dynamic + +Ticks that we stopped at the port in the past. + +future_stop_list +**************** + +type: int +slots: dynamic + +Stop indices that we will stop in the future. + +future_stop_tick_list +********************* + +type: int +slots: dynamic + +Ticks that we will stop in the future. + +matrices +++++++++ + +Matrices node is used to store big matrix for ports, vessels and containers. + +full_on_ports +************* + +type: int +slots: port number * port number + +Distribution of full from port to port. 
+ +full_on_vessels +*************** + +type: int +slots: vessel number * port number + +Distribution of full from vessel to port. + +vessel_plans +************ + +type: int +slots: vessel number * port number + +Planed route info for vessels. + +How to +~~~~~~ + +How to use the matrix(s) +++++++++++++++++++++++++ + +Matrix is special that it only have one instance (index 0), and the value is saved as a flat 1 dim array, we can reshape it after querying. + +.. code-block:: python + + # assuming that we want to use full_on_ports attribute. + + tick = 0 + + # we can get the instance number of a node by calling the len method + port_number = len(env.snapshot_list["port"]) + + # this is a 1 dim numpy array + full_on_ports = env.snapshot_list["matrices"][tick::"full_on_ports"] + + # reshape it, then this is a 2 dim array that from port to port. + full_on_ports = full_on_ports.reshape(port_number, port_number) + +Citi-Bike +--------- + +Default settings for snapshot list +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Snapshot resolution: 60 + + +Max snapshot number: same as durations + +Nodes and attributes in scenario +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +station ++++++++ + +bikes +***** + +type: int +slots: 1 + +How many bikes avaiable in current station. + +shortage +******** + +type: int +slots: 1 + +Per tick state. Lack number of bikes in current station. + +trip_requirement +**************** + +type: int +slots: 1 + +Per tick states. How many requirements in current station. + +fulfillment +*********** + +type: int +slots: 1 + +How many requirement is fit in current station. + +capacity +******** + +type: int +slots: 1 + +Max number of bikes this station can take. + +id ++++ + +type: int +slots: 1 + +Id of current station. + +weekday +******* + +type: short +slots: 1 + +Weekday at current tick. + +temperature +*********** + +type: short +slots: 1 + +Temperature at current tick. + +weather +******* + +type: short +slots: 1 + +Weather at current tick. + +0: sunny, 1: rainy, 2: snowy, 3: sleet. + +holiday +******* + +type: short +slots: 1 + +If it is holidy at current tick. + +0: holiday, 1: not holiday + +extra_cost +********** + +type: int +slots: 1 + +Cost after we reach the capacity after executing action, we have to move extra bikes +to other stations. + +transfer_cost +************* + +type: int +slots: 1 + +Cost to execute action to transfer bikes to other station. + +failed_return +************* + +type: int +slots: 1 + +Per tick state. How many bikes failed to return to current station. + +min_bikes +********* + +type: int +slots: 1 + +Min bikes number in a frame. + +matrices +++++++++ + +trips_adj +********* + +type: int +slots: station number * station number + +Used to store trip requirement number between 2 stations. + + +VM-scheduling +------------- + +Default settings for snapshot list +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Snapshot resolution: 1 + + +Max snapshot number: same as durations + +Nodes and attributes in scenario +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Cluster ++++++++ + +id +*** + +type: short +slots: 1 + +Id of the cluster. + +region_id +********* + +type: short +slots: 1 + +Region is of current cluster. + +data_center_id +************** + +type: short +slots: 1 + +Data center id of current cluster. + +total_machine_num +****************** + +type: int +slots: 1 + +Total number of machines in the cluster. + +empty_machine_num +****************** + +type: int +slots: 1 + +The number of empty machines in this cluster. A empty machine means that its allocated CPU cores are 0. 
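The cluster counters above can be queried with the same batch slicing pattern as the other scenarios. The snippet below is a minimal sketch; the node name ``clusters`` is an assumption here, so check ``env.summary["node_mapping"]`` for the exact node names exposed by your MARO version.

.. code-block:: python

    # Machine counters of all clusters at tick 0.
    cluster_snapshots = env.snapshot_list["clusters"]
    num_clusters = len(cluster_snapshots)

    # The query returns a flat numpy array; reshape it to (cluster, attribute) for readability.
    counters = cluster_snapshots[0::["total_machine_num", "empty_machine_num"]]
    counters = counters.reshape(num_clusters, 2)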
+ +data_centers +++++++++++++ + +id +*** + +type: short +slots: 1 + +Id of current data center. + +region_id +********* + +type: short +slots: 1 + +Region id of current data center. + +zone_id +******* + +type: short +slots: 1 + +Zone id of current data center. + +total_machine_num +***************** + +type: int +slots: 1 + +Total number of machine in current data center. + +empty_machine_num +***************** + +type: int +slots: 1 + +The number of empty machines in current data center. + +pms ++++ + +Physical machine node. + +id +*** + +type: int +slots: 1 + +Id of current machine. + +cpu_cores_capacity +****************** + +type: short +slots: 1 + +Max number of cpu core can be used for current machine. + +memory_capacity +*************** + +type: short +slots: 1 + +Max number of memory can be used for current machine. + +pm_type +******* + +type: short +slots: 1 + +Type of current machine. + +cpu_cores_allocated +******************* + +type: short +slots: 1 + +How many cpu core is allocated. + +memory_allocated +**************** + +type: short +slots: 1 + +How many memory is allocated. + +cpu_utilization +*************** + +type: float +slots: 1 + +CPU utilization of current machine. + +energy_consumption +****************** + +type: float +slots: 1 + +Energy consumption of current machine. + +oversubscribable +**************** + +type: short +slots: 1 + +Physical machine type: non-oversubscribable is -1, empty: 0, oversubscribable is 1. + +region_id +********* + +type: short +slots: 1 + +Region id of current machine. + +zone_id +******* + +type: short +slots: 1 + +Zone id of current machine. + +data_center_id +************** + +type: short +slots: 1 + +Data center id of current machine. + +cluster_id +********** + +type: short +slots: 1 + +Cluster id of current machine. + +rack_id +******* + +type: short +slots: 1 + +Rack id of current machine. + +Rack +++++ + +id +*** + +type: int +slots: 1 + +Id of current rack. + +region_id +********* + +type: short +slots: 1 + +Region id of current rack. + +zone_id +******* + +type: short +slots: 1 + +Zone id of current rack. + +data_center_id +************** + +type: short +slots: 1 + +Data center id of current rack. + +cluster_id +********** + +type: short +slots: 1 + +Cluster id of current rack. + +total_machine_num +***************** + +type: int +slots: 1 + +Total number of machines on this rack. + +empty_machine_num +***************** + +type: int +slots: 1 + +Number of machines that not in use on this rack. + +regions ++++++++ + +id +*** + +type: short +slots: 1 + +Id of curent region. + +total_machine_num +***************** + +type: int +slots: 1 + +Total number of machines in this region. + +empty_machine_num +***************** + +type: int +slots: 1 + +Number of machines that not in use in this region. + +zones ++++++ + +id +*** + +type: short +slots: 1 + +Id of this zone. + +total_machine_num +***************** + +type: int +slots: 1 + +Total number of machines in this zone. + +empty_machine_num +***************** + +type: int +slots: 1 + +Number of machines that not in use in this zone. diff --git a/docs/source/key_components/rl_toolkit.rst b/docs/source/key_components/rl_toolkit.rst index be4811543..16ec97bda 100644 --- a/docs/source/key_components/rl_toolkit.rst +++ b/docs/source/key_components/rl_toolkit.rst @@ -1,121 +1,198 @@ - RL Toolkit ========== -MARO provides a full-stack abstraction for reinforcement learning (RL), which enables users to -apply predefined and customized components to various scenarios. 
The main abstractions include -fundamental components such as `Agent <#agent>`_\ and `Shaper <#shaper>`_\ , and training routine -controllers such as `Actor <#actor>` and `Learner <#learner>`. +MARO provides a full-stack abstraction for reinforcement learning (RL) which includes various customizable +components. In order to provide a gentle introduction for the RL toolkit, we cover the components in a top-down +manner, starting from the learning workflow. +Workflow +-------- -Agent ------ +The nice thing about MARO's RL workflows is that it is abstracted neatly from business logic, policies and learning algorithms, +making it applicable to practically any scenario that utilizes standard reinforcement learning paradigms. The workflow is +controlled by a main process that executes 2-phase learning cycles: roll-out and training (:numref:`1`). The roll-out phase +collects data from one or more environment simulators for training. There can be a single environment simulator located in the same thread as the main +loop, or multiple environment simulators running in parallel on a set of remote workers (:numref:`2`) if you need to collect large amounts of data +fast. The training phase uses the data collected during the roll-out phase to train models involved in RL policies and algorithms. +In the case of multiple large models, this phase can be made faster by having the computationally intensive gradient-related tasks +sent to a set of remote workers for parallel processing (:numref:`3`). -The Agent is the kernel abstraction of the RL formulation for a real-world problem. -Our abstraction decouples agent and its underlying model so that an agent can exist -as an RL paradigm independent of the inner workings of the models it uses to generate -actions or estimate values. For example, the actor-critic algorithm does not need to -concern itself with the structures and optimizing schemes of the actor and critic models. -This decoupling is achieved by the Core Model abstraction described below. +.. _1: +.. figure:: ../images/rl/learning_workflow.svg + :alt: Overview + :align: center + Learning Workflow -.. image:: ../images/rl/agent.svg - :target: ../images/rl/agent.svg - :alt: Agent -.. code-block:: python +.. _2: +.. figure:: ../images/rl/parallel_rollout.svg + :alt: Overview + :align: center + + Parallel Roll-out + + +.. _3: +.. figure:: ../images/rl/distributed_training.svg + :alt: Overview + :align: center + + Distributed Training + + +Environment Sampler +------------------- + +An environment sampler is an entity that contains an environment simulator and a set of policies used by agents to +interact with the environment (:numref:`4`). When creating an RL formulation for a scenario, it is necessary to define an environment +sampler class that includes these key elements: + +- how observations / snapshots of the environment are encoded into state vectors as input to the policy models. This + is sometimes referred to as state shaping in applied reinforcement learning; +- how model outputs are converted to action objects defined by the environment simulator; +- how rewards / penalties are evaluated. This is sometimes referred to as reward shaping. + +In parallel roll-out, each roll-out worker should have its own environment sampler instance. + + +.. _4: +.. 
figure:: ../images/rl/env_sampler.svg + :alt: Overview + :align: center + + Environment Sampler + + +Policy +------ - class AbsAgent(ABC): - def __init__(self, model: AbsCoreModel, config, experience_pool=None): - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = model.to(self.device) - self.config = config - self._experience_pool = experience_pool +``Policy`` is the most important concept in reinforcement learning. In MARO, the highest level abstraction of a policy +object is ``AbsPolicy``. It defines the interface ``get_actions()`` which takes a batch of states as input and returns +corresponding actions. +The action is defined by the policy itself. It could be a scalar or a vector or any other types. +Env sampler should take responsibility for parsing the action to the acceptable format before passing it to the +environment. +The simplest type of policy is ``RuleBasedPolicy`` which generates actions by pre-defined rules. ``RuleBasedPolicy`` +is mostly used in naive scenarios. However, in most cases where we need to train the policy by interacting with the +environment, we need to use ``RLPolicy``. In MARO's design, a policy cannot train itself. Instead, +polices could only be trained by :ref:`trainer` (we will introduce trainer later in this page). Therefore, in addition +to ``get_actions()``, ``RLPolicy`` also has a set of training-related interfaces, such as ``step()``, ``get_gradients()`` +and ``set_gradients()``. These interfaces will be called by trainers for training. As you may have noticed, currently +we assume policies are built upon deep learning models, so the training-related interfaces are specifically +designed for gradient descent. -Core Model ----------- -MARO provides an abstraction for the underlying models used by agents to form policies and estimate values. -The abstraction consists of ``AbsBlock`` and ``AbsCoreModel``, both of which subclass torch's nn.Module. -The ``AbsBlock`` represents the smallest structural unit of an NN-based model. For instance, the ``FullyConnectedBlock`` -provided in the toolkit is a stack of fully connected layers with features like batch normalization, -drop-out and skip connection. The ``AbsCoreModel`` is a collection of network components with -embedded optimizers and serves as an agent's "brain" by providing a unified interface to it. regardless of how many individual models it requires and how -complex the model architecture might be. +``RLPolicy`` is further divided into three types: +- ``ValueBasedPolicy``: For valued-based policies. +- ``DiscretePolicyGradient``: For gradient-based policies that generate discrete actions. +- ``ContinuousPolicyGradient``: For gradient-based policies that generate continuous actions. -As an example, the initialization of the actor-critic algorithm may look like this: +The above classes are all concrete classes. Users do not need to implement any new classes, but can directly +create a policy object by configuring parameters. Here is a simple example: .. code-block:: python - actor_stack = FullyConnectedBlock(...) - critic_stack = FullyConnectedBlock(...) 
- model = SimpleMultiHeadModel( - {"actor": actor_stack, "critic": critic_stack}, - optim_option={ - "actor": OptimizerOption(cls=Adam, params={"lr": 0.001}) - "critic": OptimizerOption(cls=RMSprop, params={"lr": 0.0001}) - } - ) - agent = ActorCritic("actor_critic", learning_model, config) + ValueBasedPolicy( + name="policy", + q_net=MyQNet(state_dim=128, action_num=64), + ) -Choosing an action is simply: + +For now, you may have no idea about the ``q_net`` parameter, but don't worry, we will introduce it in the next section. + +Model +----- + +The above code snippet creates a ``ValueBasedPolicy`` object. Let's pay attention to the parameter ``q_net``. +``q_net`` accepts a ``DiscreteQNet`` object, and it serves as the core part of a ``ValueBasedPolicy`` object. In +other words, ``q_net`` defines the model structure of the Q-network in the value-based policy, and further determines +the policy's behavior. ``DiscreteQNet`` is an abstract class, and ``MyQNet`` is a user-defined implementation +of ``DiscreteQNet``. It can be a simple MLP, a multi-head transformer, or any other structure that the user wants. + +MARO provides a set of abstractions of basic & commonly used PyTorch models like ``DiscereteQNet``, which enables +users to implement their own deep learning models in a handy way. They are: + +- ``DiscreteQNet``: For ``ValueBasedPolicy``. +- ``DiscretePolicyNet``: For ``DiscretePolicyGradient``. +- ``ContinuousPolicyNet``: For ``ContinuousPolicyGradient``. + +Users should choose the proper types of models according to the type of policies, and then implement their own +models by inheriting the abstract ones (just like ``MyQNet``). + +There are also some other models for training purposes. For example: + +- ``VNet``: Used in the critic part in the actor-critic algorithm. +- ``MultiQNet``: Used in the critic part in the MADDPG algorithm. +- ... + +The way to use these models is exactly the same as the way to use the policy models. + +.. _trainer: + +Algorithm (Trainer) +------- + +When introducing policies, we mentioned that policies cannot train themselves. Instead, they have to be trained +by external algorithms, which are also called trainers. +In MARO, a trainer represents an RL algorithm, such as DQN, actor-critic, +and so on. These two concepts are equivalent in the MARO context. +Trainers take interaction experiences and store them in the internal memory, and then use the experiences +in the memory to train the policies. Like ``RLPolicy``, trainers are also concrete classes, which means they could +be used by configuring parameters. Currently, we have 4 trainers (algorithms) in MARO: + +- ``DiscreteActorCriticTrainer``: Actor-critic algorithm for policies that generate discrete actions. +- ``DiscretePPOTrainer``: PPO algorithm for policies that generate discrete actions. +- ``DDPGTrainer``: DDPG algorithm for policies that generate continuous actions. +- ``DQNTrainer``: DQN algorithm for policies that generate discrete actions. +- ``DiscreteMADDPGTrainer``: MADDPG algorithm for policies that generate discrete actions. + +Each trainer has a corresponding ``Param`` class to manage all related parameters. For example, +``DiscreteActorCriticParams`` contains all parameters used in ``DiscreteActorCriticTrainer``: .. 
code-block:: python - model(state, task_name="actor", training=False) + @dataclass + class DiscreteActorCriticParams(TrainerParams): + get_v_critic_net_func: Callable[[], VNet] = None + reward_discount: float = 0.9 + grad_iters: int = 1 + critic_loss_cls: Callable = None + clip_ratio: float = None + lam: float = 0.9 + min_logp: Optional[float] = None -And performing one gradient step is simply: +An example of creating an actor-critic trainer: .. code-block:: python - model.learn(critic_loss + actor_loss) + DiscreteActorCriticTrainer( + name='ac', + params=DiscreteActorCriticParams( + get_v_critic_net_func=lambda: MyCriticNet(state_dim=128), + reward_discount=.0, + grad_iters=10, + critic_loss_cls=torch.nn.SmoothL1Loss, + min_logp=None, + lam=.0 + ) + ) +In order to indicate which trainer each policy is trained by, in MARO, we require that the name of the policy +start with the name of the trainer responsible for training it. For example, policy ``ac_1.policy_1`` is trained +by the trainer named ``ac_1``. Violating this provision will make MARO unable to correctly establish the +corresponding relationship between policy and trainer. -Explorer --------- +More details and examples can be found in the code base (`link`_). -MARO provides an abstraction for exploration in RL. Some RL algorithms such as DQN and DDPG require -explicit exploration governed by a set of parameters. The ``AbsExplorer`` class is designed to cater -to these needs. Simple exploration schemes, such as ``EpsilonGreedyExplorer`` for discrete action space -and ``UniformNoiseExplorer`` and ``GaussianNoiseExplorer`` for continuous action space, are provided in -the toolkit. +.. _link: https://github.com/microsoft/maro/blob/master/examples/rl/cim/policy_trainer.py -As an example, the exploration for DQN may be carried out with the aid of an ``EpsilonGreedyExplorer``: +As a summary, the relationship among policy, model, and trainer is demonstrated in :numref:`5`: -.. code-block:: python +.. _5: +.. figure:: ../images/rl/policy_model_trainer.svg + :alt: Overview + :align: center - explorer = EpsilonGreedyExplorer(num_actions=10) - greedy_action = learning_model(state, training=False).argmax(dim=1).data - exploration_action = explorer(greedy_action) - - -Tools for Training ------------------------------- - -.. image:: ../images/rl/learner_actor.svg - :target: ../images/rl/learner_actor.svg - :alt: RL Overview - -The RL toolkit provides tools that make local and distributed training easy: -* Learner, the central controller of the learning process, which consists of collecting simulation data from - remote actors and training the agents with them. The training data collection can be done in local or - distributed fashion by loading an ``Actor`` or ``ActorProxy`` instance, respectively. -* Actor, which implements the ``roll_out`` method where the agent interacts with the environment for one - episode. It consists of an environment instance and an agent (a single agent or multiple agents wrapped by - ``MultiAgentWrapper``). The class provides the as_worker() method which turns it to an event loop where roll-outs - are performed on the learner's demand. In distributed RL, there are typically many actor processes running - simultaneously to parallelize training data collection. -* Actor proxy, which also implements the ``roll_out`` method with the same signature, but manages a set of remote - actors for parallel data collection. 
-* Trajectory, which is primarily responsible for translating between scenario-specific information and model - input / output. It implements the following methods which are used as callbacks in the actor's roll-out loop: - * ``get_state``, which converts observations of an environment into model input. For example, the observation - may be represented by a multi-level data structure, which gets encoded by a state shaper to a one-dimensional - vector as input to a neural network. The state shaper usually goes hand in hand with the underlying policy - or value models. - * ``get_action``, which provides model output with necessary context so that it can be executed by the - environment simulator. - * ``get_reward``, which computes a reward for a given action. - * ``on_env_feedback``, which defines things to do upon getting feedback from the environment. - * ``on_finish``, which defines things to do upon completion of a roll-out episode. + Summary of policy, model, and trainer diff --git a/examples/cim/README.md b/examples/cim/README.md deleted file mode 100644 index 44b657494..000000000 --- a/examples/cim/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Container Inventory Management - -Container inventory management (CIM) is a scenario where reinforcement learning (RL) can potentially prove useful. Three algorithms are used to learn the multi-agent policy in given environments. Each algorithm has a ``config`` folder which contains ``agent_config.py`` and ``training_config.py``. The former contains parameters for the underlying models and algorithm specific hyper-parameters. The latter contains parameters for the environment and the main training loop. The file ``common.py`` contains parameters and utility functions shared by some or all of these algorithms. - -In the ``ac`` folder, , the policy is trained using the Actor-Critc algorithm in single-threaded fashion. The example can be run by simply executing ``python3 main.py``. Logs will be saved in a file named ``cim-ac.CURRENT_TIME_STAMP.log`` under the ``ac/logs`` folder, where ``CURRENT_TIME_STAMP`` is the time of executing the script. - -In the ``dqn`` folder, the policy is trained using the DQN algorithm in multi-process / distributed mode. This example can be run in three ways. -* ``python3 main.py`` or ``python3 main.py -w 0`` runs the example in multi-process mode, in which a main process spawns one learner process and a number of actor processes as specified in ``config/training_config.py``. -* ``python3 main.py -w 1`` launches the learner process only. This is for distributed training and expects a number of actor processes (as specified in ``config/training_config.py``) running on some other node(s). -* ``python3 main.py -w 2`` launches the actor process only. This is for distributed training and expects a learner process running on some other node. -Logs will be saved in a file named ``GROUP_NAME.log`` under the ``{ac_gnn, dqn}/logs`` folder, where ``GROUP_NAME`` is specified in the "group" field in ``config/training_config.py``. diff --git a/examples/cim/ac/config/__init__.py b/examples/cim/ac/config/__init__.py deleted file mode 100644 index 4492cf223..000000000 --- a/examples/cim/ac/config/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -from .agent_config import agent_config -from .training_config import training_config - -__all__ = ["agent_config", "training_config"] diff --git a/examples/cim/ac/config/agent_config.py b/examples/cim/ac/config/agent_config.py deleted file mode 100644 index ecc87a80f..000000000 --- a/examples/cim/ac/config/agent_config.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from torch import nn -from torch.optim import Adam, RMSprop - -from maro.rl import OptimOption - -from examples.cim.common import common_config - -input_dim = ( - (common_config["look_back"] + 1) * - (common_config["max_ports_downstream"] + 1) * - len(common_config["port_attributes"]) + - len(common_config["vessel_attributes"]) -) - -agent_config = { - "model": { - "actor": { - "input_dim": input_dim, - "output_dim": len(common_config["action_space"]), - "hidden_dims": [256, 128, 64], - "activation": nn.Tanh, - "softmax": True, - "batch_norm": False, - "head": True - }, - "critic": { - "input_dim": input_dim, - "output_dim": 1, - "hidden_dims": [256, 128, 64], - "activation": nn.LeakyReLU, - "softmax": False, - "batch_norm": True, - "head": True - } - }, - "optimization": { - "actor": OptimOption(optim_cls=Adam, optim_params={"lr": 0.001}), - "critic": OptimOption(optim_cls=RMSprop, optim_params={"lr": 0.001}) - }, - "hyper_params": { - "reward_discount": .0, - "critic_loss_func": nn.SmoothL1Loss(), - "train_iters": 10, - "actor_loss_coefficient": 0.1, - "k": 1, - "lam": 0.0 - # "clip_ratio": 0.8 - } -} diff --git a/examples/cim/ac/config/training_config.py b/examples/cim/ac/config/training_config.py deleted file mode 100644 index c93b2c56d..000000000 --- a/examples/cim/ac/config/training_config.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -training_config = { - "env": { - "scenario": "cim", - "topology": "toy.4p_ssdd_l0.0", - "durations": 1120, - }, - "max_episode": 50 -} diff --git a/examples/cim/ac/main.py b/examples/cim/ac/main.py deleted file mode 100644 index f9bf280ec..000000000 --- a/examples/cim/ac/main.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -import numpy as np - -from maro.rl import ( - Actor, ActorCritic, ActorCriticConfig, FullyConnectedBlock, MultiAgentWrapper, SimpleMultiHeadModel, - Scheduler, OnPolicyLearner -) -from maro.simulator import Env -from maro.utils import set_seeds - -from examples.cim.ac.config import agent_config, training_config -from examples.cim.common import CIMTrajectory, common_config - - -def get_ac_agent(): - actor_net = FullyConnectedBlock(**agent_config["model"]["actor"]) - critic_net = FullyConnectedBlock(**agent_config["model"]["critic"]) - ac_model = SimpleMultiHeadModel( - {"actor": actor_net, "critic": critic_net}, optim_option=agent_config["optimization"], - ) - return ActorCritic(ac_model, ActorCriticConfig(**agent_config["hyper_params"])) - - -class CIMTrajectoryForAC(CIMTrajectory): - def on_finish(self): - training_data = {} - for event, state, action in zip(self.trajectory["event"], self.trajectory["state"], self.trajectory["action"]): - agent_id = list(state.keys())[0] - data = training_data.setdefault(agent_id, {"args": [[] for _ in range(4)]}) - data["args"][0].append(state[agent_id]) # state - data["args"][1].append(action[agent_id][0]) # action - data["args"][2].append(action[agent_id][1]) # log_p - data["args"][3].append(self.get_offline_reward(event)) # reward - - for agent_id in training_data: - training_data[agent_id]["args"] = [ - np.asarray(vals, dtype=np.float32 if i == 3 else None) - for i, vals in enumerate(training_data[agent_id]["args"]) - ] - - return training_data - - -# Single-threaded launcher -if __name__ == "__main__": - set_seeds(1024) # for reproducibility - env = Env(**training_config["env"]) - agent = MultiAgentWrapper({name: get_ac_agent() for name in env.agent_idx_list}) - actor = Actor(env, agent, CIMTrajectoryForAC, trajectory_kwargs=common_config) # local actor - learner = OnPolicyLearner(actor, training_config["max_episode"]) - learner.run() diff --git a/examples/cim/common.py b/examples/cim/common.py deleted file mode 100644 index d6cea7042..000000000 --- a/examples/cim/common.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -from collections import defaultdict - -import numpy as np - -from maro.rl import Trajectory -from maro.simulator.scenarios.cim.common import Action, ActionType - -common_config = { - "port_attributes": ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"], - "vessel_attributes": ["empty", "full", "remaining_space"], - "action_space": list(np.linspace(-1.0, 1.0, 21)), - # Parameters for computing states - "look_back": 7, - "max_ports_downstream": 2, - # Parameters for computing actions - "finite_vessel_space": True, - "has_early_discharge": True, - # Parameters for computing rewards - "reward_time_window": 99, - "fulfillment_factor": 1.0, - "shortage_factor": 1.0, - "time_decay": 0.97 -} - - -class CIMTrajectory(Trajectory): - def __init__( - self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream, - reward_time_window, fulfillment_factor, shortage_factor, time_decay, - finite_vessel_space=True, has_early_discharge=True - ): - super().__init__(env) - self.port_attributes = port_attributes - self.vessel_attributes = vessel_attributes - self.action_space = action_space - self.look_back = look_back - self.max_ports_downstream = max_ports_downstream - self.reward_time_window = reward_time_window - self.fulfillment_factor = fulfillment_factor - self.shortage_factor = shortage_factor - self.time_decay = time_decay - self.finite_vessel_space = finite_vessel_space - self.has_early_discharge = has_early_discharge - - def get_state(self, event): - vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"] - tick, port_idx, vessel_idx = event.tick, event.port_idx, event.vessel_idx - ticks = [max(0, tick - rt) for rt in range(self.look_back - 1)] - future_port_idx_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int') - port_features = port_snapshots[ticks: [port_idx] + list(future_port_idx_list): self.port_attributes] - vessel_features = vessel_snapshots[tick: vessel_idx: self.vessel_attributes] - return {port_idx: np.concatenate((port_features, vessel_features))} - - def get_action(self, action_by_agent, event): - vessel_snapshots = self.env.snapshot_list["vessels"] - action_info = list(action_by_agent.values())[0] - model_action = action_info[0] if isinstance(action_info, tuple) else action_info - scope, tick, port, vessel = event.action_scope, event.tick, event.port_idx, event.vessel_idx - zero_action_idx = len(self.action_space) / 2 # index corresponding to value zero. 
- vessel_space = vessel_snapshots[tick:vessel:self.vessel_attributes][2] if self.finite_vessel_space else float("inf") - early_discharge = vessel_snapshots[tick:vessel:"early_discharge"][0] if self.has_early_discharge else 0 - percent = abs(self.action_space[model_action]) - - if model_action < zero_action_idx: - action_type = ActionType.LOAD - actual_action = min(round(percent * scope.load), vessel_space) - elif model_action > zero_action_idx: - action_type = ActionType.DISCHARGE - plan_action = percent * (scope.discharge + early_discharge) - early_discharge - actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge) - else: - actual_action, action_type = 0, ActionType.LOAD - - return {port: Action(vessel, port, actual_action, action_type)} - - def get_offline_reward(self, event): - port_snapshots = self.env.snapshot_list["ports"] - start_tick = event.tick + 1 - ticks = list(range(start_tick, start_tick + self.reward_time_window)) - - future_fulfillment = port_snapshots[ticks::"fulfillment"] - future_shortage = port_snapshots[ticks::"shortage"] - decay_list = [ - self.time_decay ** i for i in range(self.reward_time_window) - for _ in range(future_fulfillment.shape[0] // self.reward_time_window) - ] - - tot_fulfillment = np.dot(future_fulfillment, decay_list) - tot_shortage = np.dot(future_shortage, decay_list) - - return np.float32(self.fulfillment_factor * tot_fulfillment - self.shortage_factor * tot_shortage) - - def on_env_feedback(self, event, state_by_agent, action_by_agent, reward): - self.trajectory["event"].append(event) - self.trajectory["state"].append(state_by_agent) - self.trajectory["action"].append(action_by_agent) diff --git a/examples/cim/dqn/config/__init__.py b/examples/cim/dqn/config/__init__.py deleted file mode 100644 index 4492cf223..000000000 --- a/examples/cim/dqn/config/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from .agent_config import agent_config -from .training_config import training_config - -__all__ = ["agent_config", "training_config"] diff --git a/examples/cim/dqn/config/agent_config.py b/examples/cim/dqn/config/agent_config.py deleted file mode 100644 index 6aef90643..000000000 --- a/examples/cim/dqn/config/agent_config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -from torch import nn -from torch.optim import RMSprop - -from maro.rl import DQN, DQNConfig, FullyConnectedBlock, OptimOption, PolicyGradient, SimpleMultiHeadModel - -from examples.cim.common import common_config - -input_dim = ( - (common_config["look_back"] + 1) * - (common_config["max_ports_downstream"] + 1) * - len(common_config["port_attributes"]) + - len(common_config["vessel_attributes"]) -) - -agent_config = { - "model": { - "input_dim": input_dim, - "output_dim": len(common_config["action_space"]), # number of possible actions - "hidden_dims": [256, 128, 64], - "activation": nn.LeakyReLU, - "softmax": False, - "batch_norm": True, - "skip_connection": False, - "head": True, - "dropout_p": 0.0 - }, - "optimization": OptimOption(optim_cls=RMSprop, optim_params={"lr": 0.05}), - "hyper_params": { - "reward_discount": .0, - "loss_cls": nn.SmoothL1Loss, - "target_update_freq": 5, - "tau": 0.1, - "double": False - } -} diff --git a/examples/cim/dqn/config/training_config.py b/examples/cim/dqn/config/training_config.py deleted file mode 100644 index eb925816b..000000000 --- a/examples/cim/dqn/config/training_config.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -training_config = { - "env": { - "scenario": "cim", - "topology": "toy.4p_ssdd_l0.0", - "durations": 1120, - }, - "max_episode": 100, - "exploration": { - "parameter_names": ["epsilon"], - "split": 0.5, - "start": 0.4, - "mid": 0.32, - "end": 0.0 - }, - "training": { - "min_experiences_to_train": 1024, - "train_iter": 10, - "batch_size": 128, - "prioritized_sampling_by_loss": True - }, - "group": "cim-dqn", - "learner_update_trigger": 2, - "num_actors": 2 -} diff --git a/examples/cim/dqn/main.py b/examples/cim/dqn/main.py deleted file mode 100644 index 8cf3c6793..000000000 --- a/examples/cim/dqn/main.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -import argparse -import sys - -from collections import defaultdict -from multiprocessing import Process -from os import makedirs -from os.path import dirname, join, realpath - -from maro.rl import ( - Actor, ActorProxy, DQN, DQNConfig, FullyConnectedBlock, MultiAgentWrapper, OffPolicyLearner, - SimpleMultiHeadModel, TwoPhaseLinearParameterScheduler -) -from maro.simulator import Env -from maro.utils import set_seeds - -cim_dqn_path = dirname(realpath(__file__)) -cim_example_path = dirname(cim_dqn_path) -sys.path.insert(0, cim_example_path) - -from common import CIMTrajectory, common_config -from dqn.config import agent_config, training_config - -log_dir = join(cim_dqn_path, "log") -makedirs(log_dir, exist_ok=True) - - -def get_dqn_agent(): - q_model = SimpleMultiHeadModel( - FullyConnectedBlock(**agent_config["model"]), optim_option=agent_config["optimization"] - ) - return DQN(q_model, DQNConfig(**agent_config["hyper_params"])) - - -class CIMTrajectoryForDQN(CIMTrajectory): - def on_finish(self): - exp_by_agent = defaultdict(lambda: defaultdict(list)) - for i in range(len(self.trajectory["state"]) - 1): - agent_id = list(self.trajectory["state"][i].keys())[0] - exp = exp_by_agent[agent_id] - exp["S"].append(self.trajectory["state"][i][agent_id]) - exp["A"].append(self.trajectory["action"][i][agent_id]) - exp["R"].append(self.get_offline_reward(self.trajectory["event"][i])) - exp["S_"].append(list(self.trajectory["state"][i + 1].values())[0]) - - return dict(exp_by_agent) - - -def cim_dqn_learner(): - env = Env(**training_config["env"]) - agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list}) - scheduler = TwoPhaseLinearParameterScheduler(training_config["max_episode"], **training_config["exploration"]) - actor = ActorProxy( - training_config["group"], training_config["num_actors"], - update_trigger=training_config["learner_update_trigger"], - log_dir=log_dir - ) - learner = OffPolicyLearner(actor, scheduler, agent, **training_config["training"], log_dir=log_dir) - learner.run() - - -def cim_dqn_actor(): - env = Env(**training_config["env"]) - agent = MultiAgentWrapper({name: get_dqn_agent() for name in env.agent_idx_list}) - actor = Actor(env, agent, CIMTrajectoryForDQN, trajectory_kwargs=common_config) - actor.as_worker(training_config["group"], log_dir=log_dir) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "-w", "--whoami", type=int, choices=[0, 1, 2], default=0, - help="Identity of this process: 0 - multi-process mode, 1 - learner, 2 - actor" - ) - - args = parser.parse_args() - if args.whoami == 0: - actor_processes = [Process(target=cim_dqn_actor) for _ in range(training_config["num_actors"])] - learner_process = Process(target=cim_dqn_learner) - - for i, actor_process in enumerate(actor_processes): - set_seeds(i) # this is to ensure that the actors explore differently. - actor_process.start() - - learner_process.start() - - for actor_process in actor_processes: - actor_process.join() - - learner_process.join() - elif args.whoami == 1: - cim_dqn_learner() - elif args.whoami == 2: - cim_dqn_actor() diff --git a/examples/cim/rl/README.md b/examples/cim/rl/README.md new file mode 100644 index 000000000..e9ecc34ec --- /dev/null +++ b/examples/cim/rl/README.md @@ -0,0 +1,9 @@ +# Container Inventory Management + +This example demonstrates the use of MARO's RL toolkit to optimize container inventory management. 
The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to be loaded or discharged. The objective is to minimize the overall container shortage over a certain period of time. In this folder you can find: +* ``__init__.py``, the entrance of this example. You must expose a `rl_component_bundle_cls` interface in `__init__.py` (see the example file for details); +* ``config.py``, which contains general configurations for the scenario; +* ``algorithms/``, which contains configurations for the PPO, Actor-Critic, DQN and discrete-MADDPG algorithms, including network configurations; +* ``rl_componenet_bundle.py``, which defines all necessary components to run a RL job. You can go through the doc string of `RLComponentBundle` for detailed explanation, or just read `CIMBundle` to learn its basic usage. + +We recommend that you follow this example to write your own scenarios. diff --git a/examples/cim/rl/__init__.py b/examples/cim/rl/__init__.py new file mode 100644 index 000000000..695d90ede --- /dev/null +++ b/examples/cim/rl/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .rl_component_bundle import CIMBundle as rl_component_bundle_cls + +__all__ = [ + "rl_component_bundle_cls", +] diff --git a/examples/cim/rl/algorithms/__init__.py b/examples/cim/rl/algorithms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/cim/rl/algorithms/ac.py b/examples/cim/rl/algorithms/ac.py new file mode 100644 index 000000000..f42dffe05 --- /dev/null +++ b/examples/cim/rl/algorithms/ac.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from typing import Dict + +import torch +from torch.optim import Adam, RMSprop + +from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet +from maro.rl.policy import DiscretePolicyGradient +from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams + +actor_net_conf = { + "hidden_dims": [256, 128, 64], + "activation": torch.nn.Tanh, + "softmax": True, + "batch_norm": False, + "head": True, +} +critic_net_conf = { + "hidden_dims": [256, 128, 64], + "output_dim": 1, + "activation": torch.nn.LeakyReLU, + "softmax": False, + "batch_norm": True, + "head": True, +} +actor_learning_rate = 0.001 +critic_learning_rate = 0.001 + + +class MyActorNet(DiscreteACBasedNet): + def __init__(self, state_dim: int, action_num: int) -> None: + super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num) + self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf) + self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate) + + def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + return self._actor(states) + + +class MyCriticNet(VNet): + def __init__(self, state_dim: int) -> None: + super(MyCriticNet, self).__init__(state_dim=state_dim) + self._critic = FullyConnected(input_dim=state_dim, **critic_net_conf) + self._optim = RMSprop(self._critic.parameters(), lr=critic_learning_rate) + + def _get_v_values(self, states: torch.Tensor) -> torch.Tensor: + return self._critic(states).squeeze(-1) + + +def get_ac_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient: + return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num)) + + +def get_ac(state_dim: int, name: str) -> ActorCriticTrainer: + return ActorCriticTrainer( + name=name, + params=ActorCriticParams( + get_v_critic_net_func=lambda: MyCriticNet(state_dim), + reward_discount=.0, + grad_iters=10, + critic_loss_cls=torch.nn.SmoothL1Loss, + min_logp=None, + lam=.0, + ), + ) diff --git a/examples/cim/rl/algorithms/dqn.py b/examples/cim/rl/algorithms/dqn.py new file mode 100644 index 000000000..2194656a5 --- /dev/null +++ b/examples/cim/rl/algorithms/dqn.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
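For orientation, a minimal usage sketch of the factory functions defined in ``ac.py`` above (not part of the patch): ``state_dim`` and ``action_num`` come from ``examples/cim/rl/config.py``, and the names ``ac_0`` / ``ac_0.policy`` are illustrative, following the rule that a policy name is prefixed by the name of the trainer that trains it.

```python
from examples.cim.rl.algorithms.ac import get_ac, get_ac_policy
from examples.cim.rl.config import action_num, state_dim

# Policy "ac_0.policy" is trained by the trainer named "ac_0" (trainer-prefix naming rule).
policy = get_ac_policy(state_dim, action_num, name="ac_0.policy")
trainer = get_ac(state_dim, name="ac_0")
```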
+ +from typing import Dict + +import torch +from torch.optim import RMSprop + +from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy +from maro.rl.model import DiscreteQNet, FullyConnected +from maro.rl.policy import ValueBasedPolicy +from maro.rl.training.algorithms import DQNTrainer, DQNParams + +q_net_conf = { + "hidden_dims": [256, 128, 64, 32], + "activation": torch.nn.LeakyReLU, + "softmax": False, + "batch_norm": True, + "skip_connection": False, + "head": True, + "dropout_p": 0.0, +} +learning_rate = 0.05 + + +class MyQNet(DiscreteQNet): + def __init__(self, state_dim: int, action_num: int) -> None: + super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num) + self._fc = FullyConnected(input_dim=state_dim, output_dim=action_num, **q_net_conf) + self._optim = RMSprop(self._fc.parameters(), lr=learning_rate) + + def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + return self._fc(states) + + +def get_dqn_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy: + return ValueBasedPolicy( + name=name, + q_net=MyQNet(state_dim, action_num), + exploration_strategy=(epsilon_greedy, {"epsilon": 0.4}), + exploration_scheduling_options=[( + "epsilon", MultiLinearExplorationScheduler, { + "splits": [(2, 0.32)], + "initial_value": 0.4, + "last_ep": 5, + "final_value": 0.0, + } + )], + warmup=100, + ) + + +def get_dqn(name: str) -> DQNTrainer: + return DQNTrainer( + name=name, + params=DQNParams( + reward_discount=.0, + update_target_every=5, + num_epochs=10, + soft_update_coef=0.1, + double=False, + replay_memory_capacity=10000, + random_overwrite=False, + batch_size=32, + ), + ) diff --git a/examples/cim/rl/algorithms/maddpg.py b/examples/cim/rl/algorithms/maddpg.py new file mode 100644 index 000000000..e422f572c --- /dev/null +++ b/examples/cim/rl/algorithms/maddpg.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from functools import partial +from typing import Dict, List + +import torch +from torch.optim import Adam, RMSprop + +from maro.rl.model import DiscreteACBasedNet, FullyConnected, MultiQNet +from maro.rl.policy import DiscretePolicyGradient +from maro.rl.training.algorithms import DiscreteMADDPGTrainer, DiscreteMADDPGParams + + +actor_net_conf = { + "hidden_dims": [256, 128, 64], + "activation": torch.nn.Tanh, + "softmax": True, + "batch_norm": False, + "head": True +} +critic_net_conf = { + "hidden_dims": [256, 128, 64], + "output_dim": 1, + "activation": torch.nn.LeakyReLU, + "softmax": False, + "batch_norm": True, + "head": True +} +actor_learning_rate = 0.001 +critic_learning_rate = 0.001 + + +# ##################################################################################################################### +class MyActorNet(DiscreteACBasedNet): + def __init__(self, state_dim: int, action_num: int) -> None: + super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num) + self._actor = FullyConnected(input_dim=state_dim, output_dim=action_num, **actor_net_conf) + self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate) + + def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + return self._actor(states) + + +class MyMultiCriticNet(MultiQNet): + def __init__(self, state_dim: int, action_dims: List[int]) -> None: + super(MyMultiCriticNet, self).__init__(state_dim=state_dim, action_dims=action_dims) + self._critic = FullyConnected(input_dim=state_dim + sum(action_dims), **critic_net_conf) + self._optim = RMSprop(self._critic.parameters(), critic_learning_rate) + + def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor: + return self._critic(torch.cat([states] + actions, dim=1)).squeeze(-1) + + +def get_multi_critic_net(state_dim: int, action_dims: List[int]) -> MyMultiCriticNet: + return MyMultiCriticNet(state_dim, action_dims) + + +def get_maddpg_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient: + return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num)) + + +def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer: + return DiscreteMADDPGTrainer( + name=name, + params=DiscreteMADDPGParams( + reward_discount=.0, + num_epoch=10, + get_q_critic_net_func=partial(get_multi_critic_net, state_dim, action_dims), + shared_critic=False + ) + ) diff --git a/examples/cim/rl/algorithms/ppo.py b/examples/cim/rl/algorithms/ppo.py new file mode 100644 index 000000000..770f68a16 --- /dev/null +++ b/examples/cim/rl/algorithms/ppo.py @@ -0,0 +1,25 @@ +import torch + +from maro.rl.policy import DiscretePolicyGradient +from maro.rl.training.algorithms import PPOParams, PPOTrainer + +from .ac import MyActorNet, MyCriticNet + + +def get_ppo_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGradient: + return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num)) + + +def get_ppo(state_dim: int, name: str) -> PPOTrainer: + return PPOTrainer( + name=name, + params=PPOParams( + get_v_critic_net_func=lambda: MyCriticNet(state_dim), + reward_discount=.0, + grad_iters=10, + critic_loss_cls=torch.nn.SmoothL1Loss, + min_logp=None, + lam=.0, + clip_ratio=0.1, + ), + ) diff --git a/examples/cim/rl/config.py b/examples/cim/rl/config.py new file mode 100644 index 000000000..9da287a94 --- /dev/null +++ b/examples/cim/rl/config.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT license. + +env_conf = { + "scenario": "cim", + "topology": "toy.4p_ssdd_l0.0", + "durations": 560 +} + +if env_conf["topology"].startswith("toy"): + num_agents = int(env_conf["topology"].split(".")[1][0]) +else: + num_agents = int(env_conf["topology"].split(".")[1][:2]) + +port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"] +vessel_attributes = ["empty", "full", "remaining_space"] + +state_shaping_conf = { + "look_back": 7, + "max_ports_downstream": 2 +} + +action_shaping_conf = { + "action_space": [(i - 10) / 10 for i in range(21)], + "finite_vessel_space": True, + "has_early_discharge": True +} + +reward_shaping_conf = { + "time_window": 99, + "fulfillment_factor": 1.0, + "shortage_factor": 1.0, + "time_decay": 0.97 +} + +# obtain state dimension from a temporary env_wrapper instance +state_dim = ( + (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes) + + len(vessel_attributes) +) + +action_num = len(action_shaping_conf["action_space"]) + +algorithm = "ppo" # ac, ppo, dqn or discrete_maddpg diff --git a/examples/cim/rl/env_sampler.py b/examples/cim/rl/env_sampler.py new file mode 100644 index 000000000..b02e09b39 --- /dev/null +++ b/examples/cim/rl/env_sampler.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Any, Dict, List, Tuple, Union + +import numpy as np + +from maro.rl.rollout import AbsEnvSampler, CacheElement +from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent + +from .config import ( + action_shaping_conf, port_attributes, reward_shaping_conf, state_shaping_conf, + vessel_attributes, +) + + +class CIMEnvSampler(AbsEnvSampler): + def _get_global_and_agent_state_impl( + self, event: DecisionEvent, tick: int = None, + ) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]: + tick = self._env.tick + vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"] + port_idx, vessel_idx = event.port_idx, event.vessel_idx + ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)] + future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int') + state = np.concatenate([ + port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes], + vessel_snapshots[tick: vessel_idx: vessel_attributes] + ]) + return state, {port_idx: state} + + def _translate_to_env_action( + self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionEvent, + ) -> Dict[Any, object]: + action_space = action_shaping_conf["action_space"] + finite_vsl_space = action_shaping_conf["finite_vessel_space"] + has_early_discharge = action_shaping_conf["has_early_discharge"] + + port_idx, model_action = list(action_dict.items()).pop() + + vsl_idx, action_scope = event.vessel_idx, event.action_scope + vsl_snapshots = self._env.snapshot_list["vessels"] + vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf") + + percent = abs(action_space[model_action[0]]) + zero_action_idx = len(action_space) / 2 # index corresponding to value zero. 
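# Note: with the 21-value action space defined in config.py, zero_action_idx evaluates to 10.5,
# so model actions 0-10 (values -1.0 through 0.0) take the LOAD branch below and actions 11-20
# (positive values) take the DISCHARGE branch; the amount moved is scaled by the action's
# absolute value (percent). The trailing else is a defensive fallback that an integer index
# cannot trigger.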
+ if model_action < zero_action_idx: + action_type = ActionType.LOAD + actual_action = min(round(percent * action_scope.load), vsl_space) + elif model_action > zero_action_idx: + action_type = ActionType.DISCHARGE + early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0 + plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge + actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge) + else: + actual_action, action_type = 0, None + + return {port_idx: Action(vsl_idx, int(port_idx), actual_action, action_type)} + + def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionEvent, tick: int) -> Dict[Any, float]: + start_tick = tick + 1 + ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"])) + + # Get the ports that took actions at the given tick + ports = [int(port) for port in list(env_action_dict.keys())] + port_snapshots = self._env.snapshot_list["ports"] + future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1) + future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1) + + decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])] + rewards = np.float32( + reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list) + - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list) + ) + return {agent_id: reward for agent_id, reward in zip(ports, rewards)} + + def _post_step(self, cache_element: CacheElement) -> None: + self._info["env_metric"] = self._env.metrics + + def _post_eval_step(self, cache_element: CacheElement) -> None: + self._post_step(cache_element) + + def post_collect(self, info_list: list, ep: int) -> None: + # print the env metric from each rollout worker + for info in info_list: + print(f"env summary (episode {ep}): {info['env_metric']}") + + # print the average env metric + if len(info_list) > 1: + metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list) + avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys} + print(f"average env summary (episode {ep}): {avg_metric}") + + def post_evaluate(self, info_list: list, ep: int) -> None: + self.post_collect(info_list, ep) diff --git a/examples/cim/rl/rl_component_bundle.py b/examples/cim/rl/rl_component_bundle.py new file mode 100644 index 000000000..3fe5aeaa7 --- /dev/null +++ b/examples/cim/rl/rl_component_bundle.py @@ -0,0 +1,84 @@ +from functools import partial +from typing import Any, Callable, Dict, Optional + +from examples.cim.rl.config import action_num, algorithm, env_conf, num_agents, reward_shaping_conf, state_dim +from examples.cim.rl.env_sampler import CIMEnvSampler +from maro.rl.policy import AbsPolicy +from maro.rl.rl_component.rl_component_bundle import RLComponentBundle +from maro.rl.rollout import AbsEnvSampler +from maro.rl.training import AbsTrainer + +from .algorithms.ac import get_ac_policy +from .algorithms.dqn import get_dqn_policy +from .algorithms.maddpg import get_maddpg_policy +from .algorithms.ppo import get_ppo_policy +from .algorithms.ac import get_ac +from .algorithms.ppo import get_ppo +from .algorithms.dqn import get_dqn +from .algorithms.maddpg import get_maddpg + + +class CIMBundle(RLComponentBundle): + def get_env_config(self) -> dict: + return env_conf + + def get_test_env_config(self) -> Optional[dict]: + return None + 
+ def get_env_sampler(self) -> AbsEnvSampler: + return CIMEnvSampler(self.env, self.test_env, reward_eval_delay=reward_shaping_conf["time_window"]) + + def get_agent2policy(self) -> Dict[Any, str]: + return {agent: f"{algorithm}_{agent}.policy"for agent in self.env.agent_idx_list} + + def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]: + if algorithm == "ac": + policy_creator = { + f"{algorithm}_{i}.policy": partial(get_ac_policy, state_dim, action_num, f"{algorithm}_{i}.policy") + for i in range(num_agents) + } + elif algorithm == "ppo": + policy_creator = { + f"{algorithm}_{i}.policy": partial(get_ppo_policy, state_dim, action_num, f"{algorithm}_{i}.policy") + for i in range(num_agents) + } + elif algorithm == "dqn": + policy_creator = { + f"{algorithm}_{i}.policy": partial(get_dqn_policy, state_dim, action_num, f"{algorithm}_{i}.policy") + for i in range(num_agents) + } + elif algorithm == "discrete_maddpg": + policy_creator = { + f"{algorithm}_{i}.policy": partial(get_maddpg_policy, state_dim, action_num, f"{algorithm}_{i}.policy") + for i in range(num_agents) + } + else: + raise ValueError(f"Unsupported algorithm: {algorithm}") + + return policy_creator + + def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]: + if algorithm == "ac": + trainer_creator = { + f"{algorithm}_{i}": partial(get_ac, state_dim, f"{algorithm}_{i}") + for i in range(num_agents) + } + elif algorithm == "ppo": + trainer_creator = { + f"{algorithm}_{i}": partial(get_ppo, state_dim, f"{algorithm}_{i}") + for i in range(num_agents) + } + elif algorithm == "dqn": + trainer_creator = { + f"{algorithm}_{i}": partial(get_dqn, f"{algorithm}_{i}") + for i in range(num_agents) + } + elif algorithm == "discrete_maddpg": + trainer_creator = { + f"{algorithm}_{i}": partial(get_maddpg, state_dim, [1], f"{algorithm}_{i}") + for i in range(num_agents) + } + else: + raise ValueError(f"Unsupported algorithm: {algorithm}") + + return trainer_creator diff --git a/examples/citi_bike/greedy/__init__.py b/examples/citi_bike/greedy/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/citi_bike/online_lp/README.md b/examples/citi_bike/online_lp/README.md index 3cedc729d..9766a460c 100644 --- a/examples/citi_bike/online_lp/README.md +++ b/examples/citi_bike/online_lp/README.md @@ -99,7 +99,7 @@ demand is 34 (at a specific station, during a time interval of 20 minutes), the corresponding demand distribution shows that demand exceeding 10 bikes per time interval (20 minutes) is only 2%. -![Demand Distribution Between Tick 2400 ~ Tick 2519](./LogDemand.ny201910.2400.png) +![Demand Distribution Between Tick 2400 ~ Tick 2519](LogDemand.ny201910.2400.png) Besides, we can also find that the percentage of forecasting results that differ to the data extracted from trip log is not low. To dive deeper in the practical @@ -110,9 +110,9 @@ show the distribution of the forecasting difference to the trip log. One for the interval with the *Max Diff* (16:00-18:00), one for the interval with the highest percentage of *Diff > 5* (10:00-12:00). 
-![Demand Distribution Between Tick 2400 ~ Tick 2519](./DemandDiff.ny201910.2400.png) +![Demand Distribution Between Tick 2400 ~ Tick 2519](DemandDiff.ny201910.2400.png) -![Demand Distribution Between Tick 2040 ~ Tick 2159](./DemandDiff.ny201910.2040.png) +![Demand Distribution Between Tick 2040 ~ Tick 2159](DemandDiff.ny201910.2040.png) Maybe due to the *sparse* and *small* trip demand, and the *small* difference between the forecasting results and data extracted from the trip log data, the diff --git a/examples/citi_bike/online_lp/__init__.py b/examples/citi_bike/online_lp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/citi_bike/online_lp/launcher.py b/examples/citi_bike/online_lp/launcher.py index 13e4473a4..dfc776617 100644 --- a/examples/citi_bike/online_lp/launcher.py +++ b/examples/citi_bike/online_lp/launcher.py @@ -75,10 +75,10 @@ def _record_history(self, env_tick: int, finished_events: List[AbsEvent]): event_type = finished_events[self._next_event_idx].event_type if event_type == CitiBikeEvents.RequireBike: # TODO: Replace it with a pre-defined PayLoad. - payload = finished_events[self._next_event_idx].payload + payload = finished_events[self._next_event_idx].body demand_history[interval_idx, payload.src_station] += 1 elif event_type == CitiBikeEvents.ReturnBike: - payload: BikeReturnPayload = finished_events[self._next_event_idx].payload + payload: BikeReturnPayload = finished_events[self._next_event_idx].body supply_history[interval_idx, payload.to_station_idx] += payload.number # Update the index to the finished event that has not been processed. @@ -129,7 +129,7 @@ def __peep_at_the_future(self, env_tick: int): # Process to get the future supply from Pending Events. for pending_event in ENV.get_pending_events(tick=tick): if pending_event.event_type == CitiBikeEvents.ReturnBike: - payload: BikeReturnPayload = pending_event.payload + payload: BikeReturnPayload = pending_event.body supply[interval_idx, payload.to_station_idx] += payload.number return demand, supply diff --git a/examples/proxy/broadcast.py b/examples/proxy/broadcast.py index 417737cc3..482391b6b 100644 --- a/examples/proxy/broadcast.py +++ b/examples/proxy/broadcast.py @@ -21,13 +21,13 @@ def worker(group_name): print(f"{proxy.name}'s counter is {counter}.") # Nonrecurring receive the message from the proxy. - for msg in proxy.receive(is_continuous=False): - print(f"{proxy.name} receive message from {msg.source}.") + msg = proxy.receive_once() + print(f"{proxy.name} received message from {msg.source}.") - if msg.tag == "INC": - counter += 1 - print(f"{proxy.name} receive INC request, {proxy.name}'s count is {counter}.") - proxy.reply(message=msg, tag="done") + if msg.tag == "INC": + counter += 1 + print(f"{proxy.name} receive INC request, {proxy.name}'s count is {counter}.") + proxy.reply(message=msg, tag="done") def master(group_name: str, worker_num: int, is_immediate: bool = False): diff --git a/examples/proxy/scatter.py b/examples/proxy/scatter.py index 5d2cff8f7..36cbb295c 100644 --- a/examples/proxy/scatter.py +++ b/examples/proxy/scatter.py @@ -21,12 +21,12 @@ def summation_worker(group_name): expected_peers={"master": 1}) # Nonrecurring receive the message from the proxy. - for msg in proxy.receive(is_continuous=False): - print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.payload}.") + msg = proxy.receive_once() + print(f"{proxy.name} received message from {msg.source}. 
the payload is {msg.body}.") - if msg.tag == "job": - replied_payload = sum(msg.payload) - proxy.reply(message=msg, tag="sum", payload=replied_payload) + if msg.tag == "job": + replied_payload = sum(msg.body) + proxy.reply(message=msg, tag="sum", body=replied_payload) def multiplication_worker(group_name): @@ -41,12 +41,12 @@ def multiplication_worker(group_name): expected_peers={"master": 1}) # Nonrecurring receive the message from the proxy. - for msg in proxy.receive(is_continuous=False): - print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.payload}.") + msg = proxy.receive_once() + print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.body}.") - if msg.tag == "job": - replied_payload = np.prod(msg.payload) - proxy.reply(message=msg, tag="multiply", payload=replied_payload) + if msg.tag == "job": + replied_payload = np.prod(msg.body) + proxy.reply(message=msg, tag="multiply", body=replied_payload) def master(group_name: str, sum_worker_number: int, multiply_worker_number: int, is_immediate: bool = False): @@ -73,13 +73,13 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int, # Assign sum tasks for summation workers. destination_payload_list = [] - for idx, peer in enumerate(proxy.peers_name["sum_worker"]): - data_length_per_peer = int(len(sum_list) / len(proxy.peers_name["sum_worker"])) + for idx, peer in enumerate(proxy.peers["sum_worker"]): + data_length_per_peer = int(len(sum_list) / len(proxy.peers["sum_worker"])) destination_payload_list.append((peer, sum_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer])) # Assign multiply tasks for multiplication workers. - for idx, peer in enumerate(proxy.peers_name["multiply_worker"]): - data_length_per_peer = int(len(multiple_list) / len(proxy.peers_name["multiply_worker"])) + for idx, peer in enumerate(proxy.peers["multiply_worker"]): + data_length_per_peer = int(len(multiple_list) / len(proxy.peers["multiply_worker"])) destination_payload_list.append( (peer, multiple_list[idx * data_length_per_peer:(idx + 1) * data_length_per_peer])) @@ -98,11 +98,11 @@ def master(group_name: str, sum_worker_number: int, multiply_worker_number: int, sum_result, multiply_result = 0, 1 for msg in replied_msgs: if msg.tag == "sum": - print(f"{proxy.name} receive message from {msg.source} with the sum result {msg.payload}.") - sum_result += msg.payload + print(f"{proxy.name} receive message from {msg.source} with the sum result {msg.body}.") + sum_result += msg.body elif msg.tag == "multiply": - print(f"{proxy.name} receive message from {msg.source} with the multiply result {msg.payload}.") - multiply_result *= msg.payload + print(f"{proxy.name} receive message from {msg.source} with the multiply result {msg.body}.") + multiply_result *= msg.body # Check task result correction. assert(sum(sum_list) == sum_result) diff --git a/examples/proxy/send.py b/examples/proxy/send.py index 73aa45dba..e35c92a1e 100644 --- a/examples/proxy/send.py +++ b/examples/proxy/send.py @@ -21,12 +21,12 @@ def worker(group_name): expected_peers={"master": 1}) # Nonrecurring receive the message from the proxy. - for msg in proxy.receive(is_continuous=False): - print(f"{proxy.name} receive message from {msg.source}. the payload is {msg.payload}.") + msg = proxy.receive_once() + print(f"{proxy.name} received message from {msg.source}. 
the payload is {msg.body}.") - if msg.tag == "sum": - replied_payload = sum(msg.payload) - proxy.reply(message=msg, tag="sum", payload=replied_payload) + if msg.tag == "sum": + replied_payload = sum(msg.body) + proxy.reply(message=msg, tag="sum", body=replied_payload) def master(group_name: str, is_immediate: bool = False): @@ -47,11 +47,11 @@ def master(group_name: str, is_immediate: bool = False): random_integer_list = np.random.randint(0, 100, 5) print(f"generate random integer list: {random_integer_list}.") - for peer in proxy.peers_name["worker"]: + for peer in proxy.peers["worker"]: message = SessionMessage(tag="sum", source=proxy.name, destination=peer, - payload=random_integer_list, + body=random_integer_list, session_type=SessionType.TASK) if is_immediate: session_id = proxy.isend(message) @@ -61,7 +61,7 @@ def master(group_name: str, is_immediate: bool = False): replied_msgs = proxy.send(message, timeout=-1) for msg in replied_msgs: - print(f"{proxy.name} receive {msg.source}, replied payload is {msg.payload}.") + print(f"{proxy.name} receive {msg.source}, replied payload is {msg.body}.") if __name__ == "__main__": diff --git a/examples/rl/README.md b/examples/rl/README.md new file mode 100644 index 000000000..ca3a3807e --- /dev/null +++ b/examples/rl/README.md @@ -0,0 +1,19 @@ +# Reinforcement Learning (RL) Examples + +This folder contains scenarios that employ reinforcement learning. MARO's RL toolkit provides scenario-agnostic workflows to run a variety of scenarios in single-thread, multi-process or distributed modes. + +## How to Run + +The entrance of a RL workflow is a YAML config file. For readers' convenience, we call this config file `config.yml` in the rest part of this doc. `config.yml` specifies the path of all necessary resources, definitions, and configurations to run the job. MARO provides a comprehensive template of the config file with detailed explanations (`maro/maro/rl/workflows/config/template.yml`). Meanwhile, MARO also provides several simple examples of `config.yml` under the current folder. + +There are two ways to start the RL job: +- If you only need to have a quick look and try to start an out-of-box workflow, just run `python .\examples\rl\run_rl_example.py PATH_TO_CONFIG_YAML`. For example, `python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml` will run the complete example RL training workflow of CIM scenario. If you only want to run the evaluation workflow, you could start the job with `--evaluate_only`. +- (**Require install MARO from source**) You could also start the job through MARO CLI. Use the command `maro local run [-c] path/to/your/config` to run in containerized (with `-c`) or non-containerized (without `-c`) environments. Similar, you could add `--evaluate_only` if you only need to run the evaluation workflow. + +## Create Your Own Scenarios + +You can create your own scenarios by supplying the necessary ingredients without worrying about putting them together in a workflow. It is necessary to create an ``__init__.py`` under your scenario folder (so that it can be treated as a package) and expose a `rl_component_bundle_cls` interface. The MARO's RL workflow will use this interface to create a `RLComponentBundle` instance and start the RL workflow based on it. a `RLComponentBundle` instance defines all necessary components to run a RL job. You can go through the doc string of `RLComponentBundle` for detailed explanation, or just read one of the examples to learn its basic usage. 
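As a concrete illustration, mirroring `examples/cim/rl/__init__.py` shown earlier in this patch, the scenario package's `__init__.py` can be as small as the following (`my_scenario` and `MyBundle` are placeholder names):

```python
# my_scenario/__init__.py
from .rl_component_bundle import MyBundle as rl_component_bundle_cls

__all__ = ["rl_component_bundle_cls"]
```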
+ +## Example + +For a complete example, please check `examples/cim/rl`. diff --git a/examples/rl/cim.yml b/examples/rl/cim.yml new file mode 100644 index 000000000..bb99d164a --- /dev/null +++ b/examples/rl/cim.yml @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for CIM scenario. +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. + +# Run this workflow by executing one of the following commands: +# - python .\examples\rl\run_rl_example.py .\examples\rl\cim.yml +# - (Requires installing MARO from source) maro local run .\examples\rl\cim.yml + +job: cim_rl_workflow +scenario_path: "examples/cim/rl" +log_path: "log/rl_job/cim.txt" +main: + num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training. + num_steps: null + eval_schedule: 5 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG +training: + mode: simple + load_path: null + load_episode: null + checkpointing: + path: "checkpoint/rl_job/cim" + interval: 5 + logging: + stdout: INFO + file: DEBUG diff --git a/examples/rl/run_rl_example.py b/examples/rl/run_rl_example.py new file mode 100644 index 000000000..e50c4f7f1 --- /dev/null +++ b/examples/rl/run_rl_example.py @@ -0,0 +1,15 @@ +import argparse + +from maro.cli.local.commands import run + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("conf_path", help='Path of the job deployment') + parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + run(conf_path=args.conf_path, containerize=False, evaluate_only=args.evaluate_only) diff --git a/examples/rl/vm_scheduling.yml b/examples/rl/vm_scheduling.yml new file mode 100644 index 000000000..5f4f20747 --- /dev/null +++ b/examples/rl/vm_scheduling.yml @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Example RL config file for VM scheduling scenario. +# Please refer to `maro/rl/workflows/config/template.yml` for the complete template and detailed explanations. + +# Run this workflow by executing one of the following commands: +# - python .\examples\rl\run_rl_example.py .\examples\rl\vm_scheduling.yml +# - (Requires installing MARO from source) maro local run .\examples\rl\vm_scheduling.yml + +job: vm_scheduling_rl_workflow +scenario_path: "examples/vm_scheduling/rl" +log_path: "log/rl_job/vm_scheduling.txt" +main: + num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training. 
+ num_steps: null + eval_schedule: 5 + logging: + stdout: INFO + file: DEBUG +rollout: + logging: + stdout: INFO + file: DEBUG +training: + mode: simple + load_path: null + load_episode: null + checkpointing: + path: "checkpoint/rl_job/vm_scheduling" + interval: 5 + logging: + stdout: INFO + file: DEBUG diff --git a/examples/vm_scheduling/offline_lp/__init__.py b/examples/vm_scheduling/offline_lp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/vm_scheduling/offline_lp/launcher.py b/examples/vm_scheduling/offline_lp/launcher.py index 5fd264c73..2e5f21870 100644 --- a/examples/vm_scheduling/offline_lp/launcher.py +++ b/examples/vm_scheduling/offline_lp/launcher.py @@ -22,14 +22,11 @@ config = convert_dottable(raw_config) LOG_PATH = os.path.join(FILE_PATH, "log", config.experiment_name) -if not os.path.exists(LOG_PATH): - os.makedirs(LOG_PATH) -simulation_logger = Logger(tag="simulation", format_=LogFormat.none, dump_folder=LOG_PATH, dump_mode="w", auto_timestamp=False) -ilp_logger = Logger(tag="ilp", format_=LogFormat.none, dump_folder=LOG_PATH, dump_mode="w", auto_timestamp=False) +simulation_logger = Logger(tag="simulation", format_=LogFormat.none, dump_path=LOG_PATH, dump_mode="w") +ilp_logger = Logger(tag="ilp", format_=LogFormat.none, dump_path=LOG_PATH, dump_mode="w") if __name__ == "__main__": start_time = timeit.default_timer() - env = Env( scenario=config.env.scenario, topology=config.env.topology, diff --git a/examples/vm_scheduling/rl/README.md b/examples/vm_scheduling/rl/README.md new file mode 100644 index 000000000..950c41b16 --- /dev/null +++ b/examples/vm_scheduling/rl/README.md @@ -0,0 +1,24 @@ +# Virtual Machine Scheduling + +A virtual machine (VM) scheduler is a cloud computing service component responsible for providing compute resources to satisfy user demands. A good resource allocation policy should aim to optimize several metrics at the same time, such as user wait time, profit, energy consumption and physical machine (PM) overload. Many commercial cloud providers use rule-based policies. Alternatively, the policy can also be optimized using reinforcement learning (RL) techniques, which involves simulating with historical data. This example demonstrates how DQN and Actor-Critic algorithms can be applied to this scenario. In this folder, you can find: + +* ``__init__.py``, the entrance of this example. You must expose a `rl_component_bundle_cls` interface in `__init__.py` (see the example file for details); +* ``config.py``, which contains general configurations for the scenario; +* ``algorithms/``, which contains configurations for the algorithms, including network configurations; +* ``rl_componenet_bundle.py``, which defines all necessary components to run a RL job. You can go through the doc string of `RLComponentBundle` for detailed explanation, or just read `VMBundle` to learn its basic usage. + +We recommend that you follow this example to write your own scenarios. + + +# Some Comments About the Results + +This example is meant to serve as a demonstration of using MARO's RL toolkit in a real-life scenario. In fact, we have yet to find a configuration that makes the policy learned by either DQN or Actor-Critic perform reasonably well in our experimental settings. + +For reference, the best results have been achieved by the ``Best Fit`` algorithm (see ``examples/vm_scheduling/rule_based_algorithm/best_fit.py`` for details). The over-subscription rate is 115% in the over-subscription settings. 
+ +|Topology | PM Setting | Time Spent(s) | Total VM Requests |Successful Allocation| Energy Consumption| Total Oversubscriptions | Total Overload PMs +|:----:|-----|:--------:|:---:|:-------:|:----:|:---:|:---:| +|10k| 100 PMs, 32 Cores, 128 GB | 104.98|10,000| 10,000| 2,399,610 | 0 | 0| +|10k.oversubscription| 100 PMs, 32 Cores, 128 GB| 101.00 |10,000 |10,000| 2,386,371| 279,331 | 0| +|336k| 880 PMs, 16 Cores, 112 GB | 7,896.37 |335,985| 109,249 |26,425,878 | 0 | 0 | +|336k.oversubscription| 880 PMs, 16 Cores, 112 GB | 7,903.33| 335,985| 115,008 | 27,440,946 | 3,868,475 | 0 diff --git a/examples/vm_scheduling/rl/__init__.py b/examples/vm_scheduling/rl/__init__.py new file mode 100644 index 000000000..44e5138a2 --- /dev/null +++ b/examples/vm_scheduling/rl/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .rl_component_bundle import VMBundle as rl_component_bundle_cls + +__all__ = [ + "rl_component_bundle_cls", +] diff --git a/examples/vm_scheduling/rl/algorithms/__init__.py b/examples/vm_scheduling/rl/algorithms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/vm_scheduling/rl/algorithms/ac.py b/examples/vm_scheduling/rl/algorithms/ac.py new file mode 100644 index 000000000..787730ea2 --- /dev/null +++ b/examples/vm_scheduling/rl/algorithms/ac.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Dict + +import torch +from torch.optim import Adam, SGD + +from maro.rl.model import DiscreteACBasedNet, FullyConnected, VNet +from maro.rl.policy import DiscretePolicyGradient +from maro.rl.training.algorithms import ActorCriticTrainer, ActorCriticParams + + +actor_net_conf = { + "hidden_dims": [64, 32, 32], + "activation": torch.nn.LeakyReLU, + "softmax": True, + "batch_norm": False, + "head": True, +} + +critic_net_conf = { + "hidden_dims": [256, 128, 64], + "activation": torch.nn.LeakyReLU, + "softmax": False, + "batch_norm": False, + "head": True, +} + +actor_learning_rate = 0.0001 +critic_learning_rate = 0.001 + + +class MyActorNet(DiscreteACBasedNet): + def __init__(self, state_dim: int, action_num: int, num_features: int) -> None: + super(MyActorNet, self).__init__(state_dim=state_dim, action_num=action_num) + self._num_features = num_features + self._actor = FullyConnected(input_dim=num_features, output_dim=action_num, **actor_net_conf) + self._optim = Adam(self._actor.parameters(), lr=actor_learning_rate) + + def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + features, masks = states[:, :self._num_features], states[:, self._num_features:] + masks += 1e-8 # this is to prevent zero probability and infinite logP. + return self._actor(features) * masks + + +class MyCriticNet(VNet): + def __init__(self, state_dim: int, num_features: int) -> None: + super(MyCriticNet, self).__init__(state_dim=state_dim) + self._num_features = num_features + self._critic = FullyConnected(input_dim=num_features, output_dim=1, **critic_net_conf) + self._optim = SGD(self._critic.parameters(), lr=critic_learning_rate) + + def _get_v_values(self, states: torch.Tensor) -> torch.Tensor: + features, masks = states[:, :self._num_features], states[:, self._num_features:] + masks += 1e-8 # this is to prevent zero probability and infinite logP. 
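# Note: unlike the actor above, the critic does not apply the mask; only the feature
# portion of the state is passed to the value network in the return below.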
+ return self._critic(features).squeeze(-1) + + +def get_ac_policy(state_dim: int, action_num: int, num_features: int, name: str) -> DiscretePolicyGradient: + return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num, num_features)) + + +def get_ac(state_dim: int, num_features: int, name: str) -> ActorCriticTrainer: + return ActorCriticTrainer( + name=name, + params=ActorCriticParams( + get_v_critic_net_func=lambda: MyCriticNet(state_dim, num_features), + reward_discount=0.9, + grad_iters=100, + critic_loss_cls=torch.nn.MSELoss, + min_logp=-20, + lam=.0, + ), + ) diff --git a/examples/vm_scheduling/rl/algorithms/dqn.py b/examples/vm_scheduling/rl/algorithms/dqn.py new file mode 100644 index 000000000..12a8adc98 --- /dev/null +++ b/examples/vm_scheduling/rl/algorithms/dqn.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import numpy as np +import torch +from torch.optim import SGD +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts + +from maro.rl.exploration import MultiLinearExplorationScheduler +from maro.rl.model import DiscreteQNet, FullyConnected +from maro.rl.policy import ValueBasedPolicy +from maro.rl.training.algorithms import DQNParams, DQNTrainer + +q_net_conf = { + "hidden_dims": [64, 128, 256], + "activation": torch.nn.LeakyReLU, + "softmax": False, + "batch_norm": False, + "skip_connection": False, + "head": True, + "dropout_p": 0.0, +} +q_net_learning_rate = 0.0005 +q_net_lr_scheduler_params = {"T_0": 500, "T_mult": 2} + + +class MyQNet(DiscreteQNet): + def __init__(self, state_dim: int, action_num: int, num_features: int) -> None: + super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num) + self._num_features = num_features + self._fc = FullyConnected(input_dim=num_features, output_dim=action_num, **q_net_conf) + self._optim = SGD(self._fc.parameters(), lr=q_net_learning_rate) + self._lr_scheduler = CosineAnnealingWarmRestarts(self._optim, **q_net_lr_scheduler_params) + + def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + masks = states[:, self._num_features:] + q_for_all_actions = self._fc(states[:, :self._num_features]) + return q_for_all_actions + (masks - 1) * 1e8 + + +class MaskedEpsGreedy: + def __init__(self, state_dim: int, num_features: int) -> None: + self._state_dim = state_dim + self._num_features = num_features + + def __call__(self, states, actions, num_actions, *, epsilon): + masks = states[:, self._num_features:] + return np.array([ + action if np.random.random() > epsilon else np.random.choice(np.where(mask == 1)[0]) + for action, mask in zip(actions, masks) + ]) + + +def get_dqn_policy(state_dim: int, action_num: int, num_features: int, name: str) -> ValueBasedPolicy: + return ValueBasedPolicy( + name=name, + q_net=MyQNet(state_dim, action_num, num_features), + exploration_strategy=(MaskedEpsGreedy(state_dim, num_features), {"epsilon": 0.4}), + exploration_scheduling_options=[( + "epsilon", MultiLinearExplorationScheduler, { + "splits": [(100, 0.32)], + "initial_value": 0.4, + "last_ep": 400, + "final_value": 0.0, + } + )], + warmup=100, + ) + + +def get_dqn(name: str) -> DQNTrainer: + return DQNTrainer( + name=name, + params=DQNParams( + reward_discount=0.9, + update_target_every=5, + num_epochs=100, + soft_update_coef=0.1, + double=False, + replay_memory_capacity=10000, + random_overwrite=False, + batch_size=32, + data_parallelism=2, + ), + ) diff --git a/examples/vm_scheduling/rl/config.py 
b/examples/vm_scheduling/rl/config.py new file mode 100644 index 000000000..71cf2eaa5 --- /dev/null +++ b/examples/vm_scheduling/rl/config.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from maro.simulator import Env + + +env_conf = { + "scenario": "vm_scheduling", + "topology": "azure.2019.10k", + "start_tick": 0, + "durations": 300, # 8638 + "snapshot_resolution": 1, +} + +num_pms = Env(**env_conf).business_engine.pm_amount +pm_window_size = 1 +num_features = 2 * num_pms * pm_window_size + 4 +state_dim = num_features + num_pms + 1 + +pm_attributes = ["cpu_cores_capacity", "memory_capacity", "cpu_cores_allocated", "memory_allocated"] +# vm_attributes = ["cpu_cores_requirement", "memory_requirement", "lifetime", "remain_time", "total_income"] + + +reward_shaping_conf = { + "alpha": 0.0, + "beta": 1.0, +} +seed = 666 + +test_env_conf = { + "scenario": "vm_scheduling", + "topology": "azure.2019.10k.oversubscription", + "start_tick": 0, + "durations": 300, + "snapshot_resolution": 1, +} +test_reward_shaping_conf = { + "alpha": 0.0, + "beta": 1.0, +} + +test_seed = 1024 + +algorithm = "ac" # "dqn" or "ac" diff --git a/examples/vm_scheduling/rl/env_sampler.py b/examples/vm_scheduling/rl/env_sampler.py new file mode 100644 index 000000000..1913e3adb --- /dev/null +++ b/examples/vm_scheduling/rl/env_sampler.py @@ -0,0 +1,200 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import time +from collections import defaultdict +from os import makedirs +from os.path import dirname, join, realpath +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +from matplotlib import pyplot as plt + +from maro.rl.rollout import AbsEnvSampler, CacheElement +from maro.simulator import Env +from maro.simulator.scenarios.vm_scheduling import AllocateAction, DecisionPayload, PostponeAction + +from .config import ( + num_features, pm_attributes, pm_window_size, reward_shaping_conf, seed, test_reward_shaping_conf, test_seed, +) + +timestamp = str(time.time()) +plt_path = join(dirname(realpath(__file__)), "plots", timestamp) +makedirs(plt_path, exist_ok=True) + + +class VMEnvSampler(AbsEnvSampler): + def __init__(self, learn_env: Env, test_env: Env) -> None: + super(VMEnvSampler, self).__init__(learn_env, test_env) + + self._learn_env.set_seed(seed) + self._test_env.set_seed(test_seed) + + # adjust the ratio of the success allocation and the total income when computing the reward + self.num_pms = self._learn_env.business_engine._pm_amount # the number of pms + self._durations = self._learn_env.business_engine._max_tick + self._pm_state_history = np.zeros((pm_window_size - 1, self.num_pms, 2)) + self._legal_pm_mask = None + + def _get_global_and_agent_state_impl( + self, event: DecisionPayload, tick: int = None, + ) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]: + pm_state, vm_state = self._get_pm_state(), self._get_vm_state(event) + # get the legal number of PM. 
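+ # The mask below has num_pms + 1 slots: one per PM plus a final slot for the "postpone" action. + # PMs that tie on remaining CPU are de-duplicated so that only the first (smallest-id) one stays selectable.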
+ legal_pm_mask = np.zeros(self.num_pms + 1) + if len(event.valid_pms) <= 0: + # no pm available + legal_pm_mask[self.num_pms] = 1 + else: + legal_pm_mask[self.num_pms] = 1 + remain_cpu_dict = dict() + for pm in event.valid_pms: + # If two pms have the same remaining cpu, choose the one with the smaller id + if pm_state[-1, pm, 0] not in remain_cpu_dict: + remain_cpu_dict[pm_state[-1, pm, 0]] = 1 + legal_pm_mask[pm] = 1 + else: + legal_pm_mask[pm] = 0 + + self._legal_pm_mask = legal_pm_mask + state = np.concatenate((pm_state.flatten(), vm_state.flatten(), legal_pm_mask)).astype(np.float32) + return None, {"AGENT": state} + + def _translate_to_env_action( + self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: DecisionPayload, + ) -> Dict[Any, object]: + if action_dict["AGENT"] == self.num_pms: + return {"AGENT": PostponeAction(vm_id=event.vm_id, postpone_step=1)} + else: + return {"AGENT": AllocateAction(vm_id=event.vm_id, pm_id=action_dict["AGENT"][0])} + + def _get_reward(self, env_action_dict: Dict[Any, object], event: DecisionPayload, tick: int) -> Dict[Any, float]: + action = env_action_dict["AGENT"] + conf = reward_shaping_conf if self._env == self._learn_env else test_reward_shaping_conf + if isinstance(action, PostponeAction): # postponement + if np.sum(self._legal_pm_mask) != 1: + reward = -0.1 * conf["alpha"] + 0.0 * conf["beta"] + else: + reward = 0.0 * conf["alpha"] + 0.0 * conf["beta"] + else: + reward = self._get_allocation_reward(event, conf["alpha"], conf["beta"]) if event else .0 + return {"AGENT": np.float32(reward)} + + def _get_pm_state(self): + total_pm_info = self._env.snapshot_list["pms"][self._env.frame_index::pm_attributes] + total_pm_info = total_pm_info.reshape(self.num_pms, len(pm_attributes)) + + # normalize the attributes of pms' cpu and memory + self._max_cpu_capacity = np.max(total_pm_info[:, 0]) + self._max_memory_capacity = np.max(total_pm_info[:, 1]) + total_pm_info[:, 2] /= self._max_cpu_capacity + total_pm_info[:, 3] /= self._max_memory_capacity + + # get the remaining cpu and memory of the pms + remain_cpu = (1 - total_pm_info[:, 2]).reshape(1, self.num_pms, 1) + remain_memory = (1 - total_pm_info[:, 3]).reshape(1, self.num_pms, 1) + + # get the pms' information + total_pm_info = np.concatenate((remain_cpu, remain_memory), axis=2) # (1, num_pms, 2) + + # get the sequence pms' information + self._pm_state_history = np.concatenate((self._pm_state_history, total_pm_info), axis=0) + return self._pm_state_history[-pm_window_size:, :, :] # (win_size, num_pms, 2) + + def _get_vm_state(self, event): + return np.array([ + event.vm_cpu_cores_requirement / self._max_cpu_capacity, + event.vm_memory_requirement / self._max_memory_capacity, + (self._durations - self._env.tick) * 1.0 / 200, # TODO: CHANGE 200 TO SOMETHING CONFIGURABLE + self._env.business_engine._get_unit_price(event.vm_cpu_cores_requirement, event.vm_memory_requirement) + ]) + + def _get_allocation_reward(self, event: DecisionPayload, alpha: float, beta: float): + vm_unit_price = self._env.business_engine._get_unit_price( + event.vm_cpu_cores_requirement, event.vm_memory_requirement + ) + return (alpha + beta * vm_unit_price * min(self._durations - event.frame_index, event.remaining_buffer_time)) + + def _post_step(self, cache_element: CacheElement) -> None: + self._info["env_metric"] = {k: v for k, v in self._env.metrics.items() if k != "total_latency"} + self._info["env_metric"]["latency_due_to_agent"] = self._env.metrics["total_latency"].due_to_agent + 
self._info["env_metric"]["latency_due_to_resource"] = self._env.metrics["total_latency"].due_to_resource + if "actions_by_core_requirement" not in self._info: + self._info["actions_by_core_requirement"] = defaultdict(list) + if "action_sequence" not in self._info: + self._info["action_sequence"] = [] + + action = cache_element.action_dict["AGENT"] + if cache_element.state: + mask = cache_element.state[num_features:] + self._info["actions_by_core_requirement"][cache_element.event.vm_cpu_cores_requirement].append([action, mask]) + self._info["action_sequence"].append(action) + + def _post_eval_step(self, cache_element: CacheElement) -> None: + self._post_step(cache_element) + + def post_collect(self, info_list: list, ep: int) -> None: + # print the env metric from each rollout worker + for info in info_list: + print(f"env summary (episode {ep}): {info['env_metric']}") + + # print the average env metric + if len(info_list) > 1: + metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list) + avg_metric = {key: sum(tr["env_metric"][key] for tr in info_list) / num_envs for key in metric_keys} + print(f"average env metric (episode {ep}): {avg_metric}") + + def post_evaluate(self, info_list: list, ep: int) -> None: + # print the env metric from each rollout worker + for info in info_list: + print(f"env summary (evaluation episode {ep}): {info['env_metric']}") + + # print the average env metric + if len(info_list) > 1: + metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list) + avg_metric = {key: sum(tr["env_metric"][key] for tr in info_list) / num_envs for key in metric_keys} + print(f"average env metric (evaluation episode {ep}): {avg_metric}") + + for info in info_list: + core_requirement = info["actions_by_core_requirement"] + action_sequence = info["action_sequence"] + # plot action sequence + fig = plt.figure(figsize=(40, 32)) + ax = fig.add_subplot(1, 1, 1) + ax.plot(action_sequence) + fig.savefig(f"{plt_path}/action_sequence_{ep}") + plt.cla() + plt.close("all") + + # plot with legal action mask + fig = plt.figure(figsize=(40, 32)) + for idx, key in enumerate(core_requirement.keys()): + ax = fig.add_subplot(len(core_requirement.keys()), 1, idx + 1) + for i in range(len(core_requirement[key])): + if i == 0: + ax.plot(core_requirement[key][i][0] * core_requirement[key][i][1], label=str(key)) + ax.legend() + else: + ax.plot(core_requirement[key][i][0] * core_requirement[key][i][1]) + + fig.savefig(f"{plt_path}/values_with_legal_action_{ep}") + + plt.cla() + plt.close("all") + + # plot without legal actin mask + fig = plt.figure(figsize=(40, 32)) + + for idx, key in enumerate(core_requirement.keys()): + ax = fig.add_subplot(len(core_requirement.keys()), 1, idx + 1) + for i in range(len(core_requirement[key])): + if i == 0: + ax.plot(core_requirement[key][i][0], label=str(key)) + ax.legend() + else: + ax.plot(core_requirement[key][i][0]) + + fig.savefig(f"{plt_path}/values_without_legal_action_{ep}") + + plt.cla() + plt.close("all") diff --git a/examples/vm_scheduling/rl/rl_component_bundle.py b/examples/vm_scheduling/rl/rl_component_bundle.py new file mode 100644 index 000000000..14edf47ce --- /dev/null +++ b/examples/vm_scheduling/rl/rl_component_bundle.py @@ -0,0 +1,57 @@ +from functools import partial +from typing import Any, Callable, Dict, Optional + +from examples.vm_scheduling.rl.algorithms.ac import get_ac_policy +from examples.vm_scheduling.rl.algorithms.dqn import get_dqn_policy +from examples.vm_scheduling.rl.config import algorithm, env_conf, 
num_features, num_pms, state_dim, test_env_conf +from examples.vm_scheduling.rl.env_sampler import VMEnvSampler +from maro.rl.policy import AbsPolicy +from maro.rl.rl_component.rl_component_bundle import RLComponentBundle +from maro.rl.rollout import AbsEnvSampler +from maro.rl.training import AbsTrainer + + +class VMBundle(RLComponentBundle): + def get_env_config(self) -> dict: + return env_conf + + def get_test_env_config(self) -> Optional[dict]: + return test_env_conf + + def get_env_sampler(self) -> AbsEnvSampler: + return VMEnvSampler(self.env, self.test_env) + + def get_agent2policy(self) -> Dict[Any, str]: + return {"AGENT": f"{algorithm}.policy"} + + def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]: + action_num = num_pms + 1 # action could be any PM or postponement, hence the plus 1 + + if algorithm == "ac": + policy_creator = { + f"{algorithm}.policy": partial( + get_ac_policy, state_dim, action_num, num_features, f"{algorithm}.policy", + ) + } + elif algorithm == "dqn": + policy_creator = { + f"{algorithm}.policy": partial( + get_dqn_policy, state_dim, action_num, num_features, f"{algorithm}.policy", + ) + } + else: + raise ValueError(f"Unsupported algorithm: {algorithm}") + + return policy_creator + + def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]: + if algorithm == "ac": + from .algorithms.ac import get_ac, get_ac_policy + trainer_creator = {algorithm: partial(get_ac, state_dim, num_features, algorithm)} + elif algorithm == "dqn": + from .algorithms.dqn import get_dqn, get_dqn_policy + trainer_creator = {algorithm: partial(get_dqn, algorithm)} + else: + raise ValueError(f"Unsupported algorithm: {algorithm}") + + return trainer_creator diff --git a/examples/vm_scheduling/rule_based_algorithm/__init__.py b/examples/vm_scheduling/rule_based_algorithm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/maro/README.rst b/maro/README.rst index a63d78a50..abd6ede7c 100644 --- a/maro/README.rst +++ b/maro/README.rst @@ -41,13 +41,13 @@ .. image:: https://github.com/microsoft/maro/workflows/test/badge.svg - :target: https://github.com/microsoft/maro/actions?query=workflow%3Atest - :alt: test + :target: https://github.com/microsoft/maro/actions?query=workflow%3Atest + :alt: test .. image:: https://github.com/microsoft/maro/workflows/build/badge.svg - :target: https://github.com/microsoft/maro/actions?query=workflow%3Abuild - :alt: build + :target: https://github.com/microsoft/maro/actions?query=workflow%3Abuild + :alt: build .. image:: https://github.com/microsoft/maro/workflows/docker/badge.svg @@ -56,8 +56,8 @@ .. image:: https://readthedocs.org/projects/maro/badge/?version=latest - :target: https://maro.readthedocs.io/ - :alt: docs + :target: https://maro.readthedocs.io/ + :alt: docs .. image:: https://img.shields.io/pypi/v/pymaro @@ -142,6 +142,69 @@ ================================================================================================================ + +.. image:: https://raw.githubusercontent.com/microsoft/maro/master/docs/source/images/badges/vm_scheduling.svg + :target: https://maro.readthedocs.io/en/latest/scenarios/vm_scheduling.html + :alt: VM Scheduling + + +.. image:: https://img.shields.io/gitter/room/microsoft/maro + :target: https://gitter.im/Microsoft/MARO# + :alt: Gitter + + +.. image:: https://raw.githubusercontent.com/microsoft/maro/master/docs/source/images/badges/stack_overflow.svg + :target: https://stackoverflow.com/questions/ask?tags=maro + :alt: Stack Overflow + + +.. 
image:: https://img.shields.io/github/release-date-pre/microsoft/maro + :target: https://github.com/microsoft/maro/releases + :alt: Releases + + +.. image:: https://img.shields.io/github/commits-since/microsoft/maro/latest/master + :target: https://github.com/microsoft/maro/commits/master + :alt: Commits + + +.. image:: https://github.com/microsoft/maro/workflows/vulnerability%20scan/badge.svg + :target: https://github.com/microsoft/maro/actions?query=workflow%3A%22vulnerability+scan%22 + :alt: Vulnerability Scan + + +.. image:: https://github.com/microsoft/maro/workflows/lint/badge.svg + :target: https://github.com/microsoft/maro/actions?query=workflow%3Alint + :alt: Lint + + +.. image:: https://img.shields.io/codecov/c/github/microsoft/maro + :target: https://codecov.io/gh/microsoft/maro + :alt: Coverage + + +.. image:: https://img.shields.io/pypi/dm/pymaro + :target: https://pypi.org/project/pymaro/#files + :alt: Downloads + + +.. image:: https://img.shields.io/docker/pulls/maro2020/maro + :target: https://hub.docker.com/repository/docker/maro2020/maro + :alt: Docker Pulls + + +.. image:: https://raw.githubusercontent.com/microsoft/maro/master/docs/source/images/badges/play_with_maro.svg + :target: https://hub.docker.com/r/maro2020/maro + :alt: Play with MARO + + + +.. image:: https://github.com/microsoft/maro/blob/master/docs/source/images/logo.svg + :target: https://maro.readthedocs.io/en/latest/ + :alt: MARO LOGO + +================================================================================================================ + Multi-Agent Resource Optimization (MARO) platform is an instance of Reinforcement learning as a Service (RaaS) for real-world resource optimization. It can be applied to many important industrial domains, such as `container inventory @@ -172,18 +235,18 @@ Contents -------- .. list-table:: - :header-rows: 1 + :header-rows: 1 - * - File/folder - - Description - * - ``maro`` - - MARO source code. - * - ``docs`` - - MARO docs, it is host on `readthedocs `_. - * - ``examples`` - - Showcase of MARO. - * - ``notebooks`` - - MARO quick-start notebooks. + * - File/folder + - Description + * - ``maro`` + - MARO source code. + * - ``docs`` + - MARO docs, it is host on `readthedocs `_. + * - ``examples`` + - Showcase of MARO. + * - ``notebooks`` + - MARO quick-start notebooks. *Try `MARO playground <#run-playground>`_ to have a quick experience.* @@ -199,17 +262,17 @@ Install MARO from `PyPI `_ .. code-block:: sh - pip install pymaro + pip install pymaro * Windows .. code-block:: powershell - # Install torch first, if you don't have one. - pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html + # Install torch first, if you don't have one. + pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html - pip install pymaro + pip install pymaro Install MARO from Source ------------------------ @@ -235,9 +298,9 @@ Install MARO from Source .. code-block:: sh - # If your environment is not clean, create a virtual environment firstly. - python -m venv maro_venv - source ./maro_venv/bin/activate + # If your environment is not clean, create a virtual environment firstly. + python -m venv maro_venv + source ./maro_venv/bin/activate * Windows @@ -267,16 +330,16 @@ Install MARO from Source .. code-block:: sh - # Install MARO from source. - bash scripts/install_maro.sh + # Install MARO from source. + bash scripts/install_maro.sh * Windows .. code-block:: powershell - # Install MARO from source. 
- .\scripts\install_maro.bat + # Install MARO from source. + .\scripts\install_maro.bat * *Notes: If your package is not found, remember to set your PYTHONPATH* @@ -300,16 +363,16 @@ Quick Example .. code-block:: python - from maro.simulator import Env + from maro.simulator import Env - env = Env(scenario="cim", topology="toy.5p_ssddd_l0.0", start_tick=0, durations=100) + env = Env(scenario="cim", topology="toy.5p_ssddd_l0.0", start_tick=0, durations=100) - metrics, decision_event, is_done = env.step(None) + metrics, decision_event, is_done = env.step(None) - while not is_done: - metrics, decision_event, is_done = env.step(None) + while not is_done: + metrics, decision_event, is_done = env.step(None) - print(f"environment metrics: {env.metrics}") + print(f"environment metrics: {env.metrics}") `Environment Visualization `_ ------------------------------------------------------------------------- @@ -382,8 +445,8 @@ Run Playground .. code-block:: sh - # Build playground image. - bash ./scripts/build_playground.sh + # Build playground image. + bash ./scripts/build_playground.sh # Run playground container. # Redis commander (GUI for redis) -> http://127.0.0.1:40009 @@ -395,8 +458,8 @@ Run Playground .. code-block:: powershell - # Build playground image. - .\scripts\build_playground.bat + # Build playground image. + .\scripts\build_playground.bat # Run playground container. # Redis commander (GUI for redis) -> http://127.0.0.1:40009 diff --git a/maro/backends/frame.pyx b/maro/backends/frame.pyx index d1e61a103..481f45e46 100644 --- a/maro/backends/frame.pyx +++ b/maro/backends/frame.pyx @@ -74,6 +74,15 @@ def node(name: str): return node_dec +def try_get_attribute(target, name, default=None): + try: + attr = object.__getattribute__(target, name) + + return attr + except: + return default + + cdef class NodeAttribute: def __cinit__(self, object dtype = None, SLOT_INDEX slot_num = 1, is_const = False, is_list = False): # Check the type of dtype, used to compact with old version @@ -532,6 +541,8 @@ cdef class FrameBase: else: node._is_deleted = False + # Also + cpdef void take_snapshot(self, INT tick) except *: """Take snapshot for specified point (tick) for current frame. diff --git a/maro/cli/k8s/aks/aks_commands.py b/maro/cli/k8s/aks/aks_commands.py new file mode 100644 index 000000000..22637ca10 --- /dev/null +++ b/maro/cli/k8s/aks/aks_commands.py @@ -0,0 +1,317 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
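+ # This module backs the "maro aks" command set: provisioning the Azure resources (resource group, ACR, AKS cluster, storage account and file share), building and pushing the MARO docker image, and managing RL jobs as Kubernetes jobs.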
+ +import base64 +import json +import os +import shutil +from os.path import abspath, dirname, expanduser, join + +import yaml + +from maro.cli.utils import docker as docker_utils +from maro.cli.utils.azure import storage as azure_storage_utils +from maro.cli.utils.azure.aks import attach_acr +from maro.cli.utils.azure.deployment import create_deployment +from maro.cli.utils.azure.general import connect_to_aks, get_acr_push_permissions, set_env_credentials +from maro.cli.utils.azure.resource_group import create_resource_group, delete_resource_group +from maro.cli.utils.common import show_log +from maro.rl.workflows.config import ConfigParser +from maro.utils.logger import CliLogger +from maro.utils.utils import LOCAL_MARO_ROOT + +from ..utils import k8s_manifest_generator, k8s_ops + +# metadata +CLI_AKS_PATH = dirname(abspath(__file__)) +TEMPLATE_PATH = join(CLI_AKS_PATH, "template.json") +NVIDIA_PLUGIN_PATH = join(CLI_AKS_PATH, "create_nvidia_plugin", "nvidia-device-plugin.yml") +LOCAL_ROOT = expanduser("~/.maro/aks") +DEPLOYMENT_CONF_PATH = os.path.join(LOCAL_ROOT, "conf.json") +DOCKER_FILE_PATH = join(LOCAL_MARO_ROOT, "docker_files", "dev.df") +DOCKER_IMAGE_NAME = "maro-aks" +REDIS_HOST = "maro-redis" +REDIS_PORT = 6379 +ADDRESS_REGISTRY_NAME = "address-registry" +ADDRESS_REGISTRY_PORT = 6379 +K8S_SECRET_NAME = "azure-secret" + +# display +NO_DEPLOYMENT_MSG = "No Kubernetes deployment on Azure found. Use 'maro aks init' to create a deployment first" +NO_JOB_MSG = "No job named {} has been scheduled. Use 'maro aks job add' to add the job first." +JOB_EXISTS_MSG = "A job named {} has already been scheduled." + +logger = CliLogger(name=__name__) + + +# helper functions +def get_resource_group_name(deployment_name: str): + return f"rg-{deployment_name}" + + +def get_acr_name(deployment_name: str): + return f"crmaro{deployment_name}" + + +def get_acr_server_name(acr_name: str): + return f"{acr_name}.azurecr.io" + + +def get_docker_image_name_in_acr(acr_name: str, docker_image_name: str): + return f"{get_acr_server_name(acr_name)}/{docker_image_name}" + + +def get_aks_name(deployment_name: str): + return f"aks-maro-{deployment_name}" + + +def get_agentpool_name(deployment_name: str): + return f"ap{deployment_name}" + + +def get_fileshare_name(deployment_name: str): + return f"fs-{deployment_name}" + + +def get_storage_account_name(deployment_name: str): + return f"stscenario{deployment_name}" + + +def get_virtual_network_name(location: str, deployment_name: str): + return f"vnet-prod-{location}-{deployment_name}" + + +def get_local_job_path(job_name: str): + return os.path.join(LOCAL_ROOT, job_name) + + +def get_storage_account_secret(resource_group_name: str, storage_account_name: str, namespace: str): + storage_account_keys = azure_storage_utils.get_storage_account_keys(resource_group_name, storage_account_name) + storage_key = storage_account_keys[0]["value"] + secret_data = { + "azurestorageaccountname": base64.b64encode(storage_account_name.encode()).decode(), + "azurestorageaccountkey": base64.b64encode(bytes(storage_key.encode())).decode() + } + k8s_ops.create_secret(K8S_SECRET_NAME, secret_data, namespace) + + +def get_resource_params(deployment_conf: dict) -> dict: + """Create ARM parameters for Azure resource deployment (). + + See https://docs.microsoft.com/en-us/azure/azure-resource-manager/templates/overview for details. + + Args: + deployment_conf (dict): Configuration dict for deployment on Azure. + + Returns: + dict: parameter dict, should be exported to json. 
+ """ + name = deployment_conf["name"] + return { + "acrName": get_acr_name(name), + "acrSku": deployment_conf["container_registry_service_tier"], + "systemPoolVMCount": deployment_conf["resources"]["k8s"]["vm_count"], + "systemPoolVMSize": deployment_conf["resources"]["k8s"]["vm_size"], + "userPoolName": get_agentpool_name(name), + "userPoolVMCount": deployment_conf["resources"]["app"]["vm_count"], + "userPoolVMSize": deployment_conf["resources"]["app"]["vm_size"], + "aksName": get_aks_name(name), + "location": deployment_conf["location"], + "storageAccountName": get_storage_account_name(name), + "fileShareName": get_fileshare_name(name) + # "virtualNetworkName": get_virtual_network_name(deployment_conf["location"], name) + } + + +def prepare_docker_image_and_push_to_acr(image_name: str, context: str, docker_file_path: str, acr_name: str): + # build and tag docker image locally and push to the Azure Container Registry + if not docker_utils.image_exists(image_name): + docker_utils.build_image(context, docker_file_path, image_name) + + get_acr_push_permissions(os.environ["AZURE_CLIENT_ID"], acr_name) + docker_utils.push(image_name, get_acr_server_name(acr_name)) + + +def start_redis_service_in_aks(host: str, port: int, namespace: str): + k8s_ops.load_config() + k8s_ops.create_namespace(namespace) + k8s_ops.create_deployment(k8s_manifest_generator.get_redis_deployment_manifest(host, port), namespace) + k8s_ops.create_service(k8s_manifest_generator.get_redis_service_manifest(host, port), namespace) + + +# CLI command functions +def init(deployment_conf_path: str, **kwargs): + """Prepare Azure resources needed for an AKS cluster using a YAML configuration file. + + The configuration file template can be found in cli/k8s/aks/conf.yml. Use the Azure CLI to log into + your Azure account (az login ...) and the the Azure Container Registry (az acr login ...) first. + + Args: + deployment_conf_path (str): Path to the deployment configuration file. 
+ """ + with open(deployment_conf_path, "r") as fp: + deployment_conf = yaml.safe_load(fp) + + subscription = deployment_conf["azure_subscription"] + name = deployment_conf["name"] + if os.path.isfile(DEPLOYMENT_CONF_PATH): + logger.warning(f"Deployment {name} has already been created") + return + + os.makedirs(LOCAL_ROOT, exist_ok=True) + resource_group_name = get_resource_group_name(name) + try: + # Set credentials as environment variables + set_env_credentials(LOCAL_ROOT, f"sp-{name}") + + # create resource group + resource_group = create_resource_group(subscription, resource_group_name, deployment_conf["location"]) + logger.info_green(f"Provisioned resource group {resource_group.name} in {resource_group.location}") + + # Create ARM parameters and start deployment + logger.info("Creating Azure resources...") + resource_params = get_resource_params(deployment_conf) + with open(TEMPLATE_PATH, 'r') as fp: + template = json.load(fp) + + create_deployment(subscription, resource_group_name, name, template, resource_params) + + # Attach ACR to AKS + aks_name, acr_name = resource_params["aksName"], resource_params["acrName"] + attach_acr(resource_group_name, aks_name, acr_name) + connect_to_aks(resource_group_name, aks_name) + + # build and tag docker image locally and push to the Azure Container Registry + logger.info("Preparing docker image...") + prepare_docker_image_and_push_to_acr(DOCKER_IMAGE_NAME, LOCAL_MARO_ROOT, DOCKER_FILE_PATH, acr_name) + + # start the Redis service in the k8s cluster in the deployment namespace and expose it + logger.info("Starting Redis service in the k8s cluster...") + start_redis_service_in_aks(REDIS_HOST, REDIS_PORT, name) + + # Dump the deployment configuration + with open(DEPLOYMENT_CONF_PATH, "w") as fp: + json.dump({ + "name": name, + "subscription": subscription, + "resource_group": resource_group_name, + "resources": resource_params + }, fp) + logger.info_green(f"Cluster '{name}' is created") + except Exception as e: + # If failed, remove details folder, then raise + shutil.rmtree(LOCAL_ROOT) + logger.error_red(f"Deployment {name} failed due to {e}, rolling back...") + delete_resource_group(subscription, resource_group_name) + except KeyboardInterrupt: + shutil.rmtree(LOCAL_ROOT) + logger.error_red(f"Deployment {name} aborted, rolling back...") + delete_resource_group(subscription, resource_group_name) + + +def add_job(conf_path: dict, **kwargs): + if not os.path.isfile(DEPLOYMENT_CONF_PATH): + logger.error_red(NO_DEPLOYMENT_MSG) + return + + parser = ConfigParser(conf_path) + job_name = parser.config["job"] + local_job_path = get_local_job_path(job_name) + if os.path.isdir(local_job_path): + logger.error_red(JOB_EXISTS_MSG.format(job_name)) + return + + os.makedirs(local_job_path) + with open(DEPLOYMENT_CONF_PATH, "r") as fp: + deployment_conf = json.load(fp) + + resource_group_name, resource_name = deployment_conf["resource_group"], deployment_conf["resources"] + fileshare = azure_storage_utils.get_fileshare(resource_name["storageAccountName"], resource_name["fileShareName"]) + job_dir = azure_storage_utils.get_directory(fileshare, job_name) + scenario_path = parser.config["scenario_path"] + logger.info(f"Uploading local directory {scenario_path}...") + azure_storage_utils.upload_to_fileshare(job_dir, scenario_path, name="scenario") + azure_storage_utils.get_directory(job_dir, "checkpoints") + azure_storage_utils.get_directory(job_dir, "logs") + + # Define mount volumes, i.e., scenario code, checkpoints, logs and load point + job_path_in_share = 
f"{resource_name['fileShareName']}/{job_name}" + volumes = [ + k8s_manifest_generator.get_azurefile_volume_spec(name, f"{job_path_in_share}/{name}", K8S_SECRET_NAME) + for name in ["scenario", "logs", "checkpoints"] + ] + + if "load_path" in parser.config["training"]: + load_path = parser.config["training"]["load_path"] + logger.info(f"Uploading local model directory {load_path}...") + azure_storage_utils.upload_to_fileshare(job_dir, load_path, name="loadpoint") + volumes.append( + k8s_manifest_generator.get_azurefile_volume_spec( + "loadpoint", f"{job_path_in_share}/loadpoint", K8S_SECRET_NAME) + ) + + # Start k8s jobs + k8s_ops.load_config() + k8s_ops.create_namespace(job_name) + get_storage_account_secret(resource_group_name, resource_name["storageAccountName"], job_name) + k8s_ops.create_service( + k8s_manifest_generator.get_cross_namespace_service_access_manifest( + ADDRESS_REGISTRY_NAME, REDIS_HOST, deployment_conf["name"], ADDRESS_REGISTRY_PORT + ), job_name + ) + for component_name, (script, env) in parser.get_job_spec(containerize=True).items(): + container_spec = k8s_manifest_generator.get_container_spec( + get_docker_image_name_in_acr(resource_name["acrName"], DOCKER_IMAGE_NAME), + component_name, + script, + env, + volumes + ) + manifest = k8s_manifest_generator.get_job_manifest( + resource_name["userPoolName"], + component_name, + container_spec, + volumes + ) + k8s_ops.create_job(manifest, job_name) + + +def remove_jobs(job_names: str, **kwargs): + if not os.path.isfile(DEPLOYMENT_CONF_PATH): + logger.error_red(NO_DEPLOYMENT_MSG) + return + + k8s_ops.load_config() + for job_name in job_names: + local_job_path = get_local_job_path(job_name) + if not os.path.isdir(local_job_path): + logger.error_red(NO_JOB_MSG.format(job_name)) + return + + k8s_ops.delete_job(job_name) + + +def get_job_logs(job_name: str, tail: int = -1, **kwargs): + with open(DEPLOYMENT_CONF_PATH, "r") as fp: + deployment_conf = json.load(fp) + + local_log_path = os.path.join(get_local_job_path(job_name), "log") + resource_name = deployment_conf["resources"] + fileshare = azure_storage_utils.get_fileshare(resource_name["storageAccountName"], resource_name["fileShareName"]) + job_dir = azure_storage_utils.get_directory(fileshare, job_name) + log_dir = azure_storage_utils.get_directory(job_dir, "logs") + azure_storage_utils.download_from_fileshare(log_dir, f"{job_name}.log", local_log_path) + show_log(local_log_path, tail=tail) + + +def exit(**kwargs): + try: + with open(DEPLOYMENT_CONF_PATH, "r") as fp: + deployment_conf = json.load(fp) + except FileNotFoundError: + logger.error(NO_DEPLOYMENT_MSG) + return + + name = deployment_conf["name"] + set_env_credentials(LOCAL_ROOT, f"sp-{name}") + delete_resource_group(deployment_conf["subscription"], deployment_conf["resource_group"]) diff --git a/maro/cli/k8s/aks/conf.yml b/maro/cli/k8s/aks/conf.yml new file mode 100644 index 000000000..3f249d47e --- /dev/null +++ b/maro/cli/k8s/aks/conf.yml @@ -0,0 +1,12 @@ +mode: "" +azure_subscription: your_azure_subscription_id +name: your_deployment_name +location: your_azure_service_location +container_registry_service_tier: Standard # "Basic", "Standard", "Premium", see https://docs.microsoft.com/en-us/azure/container-registry/container-registry-skus for details +resources: + k8s: + vm_size: Standard_DS2_v2 # https://docs.microsoft.com/en-us/azure/virtual-machines/sizes, https://docs.microsoft.com/en-us/azure/aks/quotas-skus-regions + vm_count: 1 # must be at least 2 for k8s to function properly. 
+ app: + vm_size: Standard_DS2_v2 # https://docs.microsoft.com/en-us/azure/virtual-machines/sizes, https://docs.microsoft.com/en-us/azure/aks/quotas-skus-regions + vm_count: 1 \ No newline at end of file diff --git a/maro/cli/k8s/aks/parameters.json b/maro/cli/k8s/aks/parameters.json new file mode 100644 index 000000000..9062b86ac --- /dev/null +++ b/maro/cli/k8s/aks/parameters.json @@ -0,0 +1,33 @@ +{ + "$schema": "https://schema.management.azure.com/schemas/2015-01-01/deploymentParameters.json#", + "contentVersion": "1.1.0.0", + "parameters": { + "acrName": { + "value": "myacr" + }, + "acrSku": { + "value": "Basic" + }, + "agentCount": { + "value": 1 + }, + "agentVMSize": { + "value": "standard_a2_v2" + }, + "clusterName": { + "value": "myaks" + }, + "fileShareName": { + "value": "myfileshare" + }, + "location": { + "value": "East US" + }, + "storageAccountName": { + "value": "mystorage" + }, + "virtualNetworkName": { + "value": "myvnet" + } + } +} diff --git a/maro/cli/k8s/aks/template.json b/maro/cli/k8s/aks/template.json new file mode 100644 index 000000000..e61bdff12 --- /dev/null +++ b/maro/cli/k8s/aks/template.json @@ -0,0 +1,157 @@ +{ + "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#", + "contentVersion": "1.1.0.0", + "parameters": { + "acrName": { + "type": "string", + "minLength": 5, + "maxLength": 50, + "metadata": { + "description": "Name of your Azure Container Registry" + } + }, + "acrSku": { + "type": "string", + "metadata": { + "description": "Tier of your Azure Container Registry." + }, + "defaultValue": "Standard", + "allowedValues": [ + "Basic", + "Standard", + "Premium" + ] + }, + "systemPoolVMCount": { + "type": "int", + "metadata": { + "description": "The number of VMs allocated for running the k8s system components." + }, + "minValue": 1, + "maxValue": 50 + }, + "systemPoolVMSize": { + "type": "string", + "metadata": { + "description": "Virtual Machine size for running the k8s system components." + } + }, + "userPoolName": { + "type": "string", + "metadata": { + "description": "Name of the user node pool." + } + }, + "userPoolVMCount": { + "type": "int", + "metadata": { + "description": "The number of VMs allocated for running the user appplication." + }, + "minValue": 1, + "maxValue": 50 + }, + "userPoolVMSize": { + "type": "string", + "metadata": { + "description": "Virtual Machine size for running the user application." + } + }, + "aksName": { + "type": "string", + "metadata": { + "description": "Name of the Managed Cluster resource." + } + }, + "location": { + "type": "string", + "metadata": { + "description": "Location of the Managed Cluster resource." + } + }, + "storageAccountName": { + "type": "string", + "metadata": { + "description": "Azure storage account name." + } + }, + "fileShareName": { + "type": "string", + "metadata": { + "description": "Azure file share name." 
+ } + } + }, + "resources": [ + { + "name": "[parameters('acrName')]", + "type": "Microsoft.ContainerRegistry/registries", + "apiVersion": "2021-09-01", + "location": "[parameters('location')]", + "sku": { + "name": "[parameters('acrSku')]" + }, + "properties": { + } + }, + { + "name": "[parameters('aksName')]", + "type": "Microsoft.ContainerService/managedClusters", + "apiVersion": "2021-10-01", + "location": "[parameters('location')]", + "properties": { + "dnsPrefix": "maro", + "agentPoolProfiles": [ + { + "name": "system", + "osDiskSizeGB": 0, + "count": "[parameters('systemPoolVMCount')]", + "vmSize": "[parameters('systemPoolVMSize')]", + "osType": "Linux", + "storageProfile": "ManagedDisks", + "mode": "System", + "type": "VirtualMachineScaleSets" + }, + { + "name": "[parameters('userPoolName')]", + "osDiskSizeGB": 0, + "count": "[parameters('userPoolVMCount')]", + "vmSize": "[parameters('userPoolVMSize')]", + "osType": "Linux", + "storageProfile": "ManagedDisks", + "mode": "User", + "type": "VirtualMachineScaleSets" + } + ], + "networkProfile": { + "networkPlugin": "azure", + "loadBalancerSku": "standard" + } + }, + "identity": { + "type": "SystemAssigned" + } + }, + { + "type": "Microsoft.Storage/storageAccounts", + "apiVersion": "2021-08-01", + "name": "[parameters('storageAccountName')]", + "location": "[parameters('location')]", + "kind": "StorageV2", + "sku": { + "name": "Standard_LRS", + "tier": "Standard" + }, + "properties": { + "accessTier": "Hot" + } + }, + { + "type": "Microsoft.Storage/storageAccounts/fileServices/shares", + "apiVersion": "2021-04-01", + "name": "[concat(parameters('storageAccountName'), '/default/', parameters('fileShareName'))]", + "dependsOn": [ + "[resourceId('Microsoft.Storage/storageAccounts', parameters('storageAccountName'))]" + ] + } + ] +} diff --git a/maro/cli/k8s/lib/modes/aks/create_aks_cluster/template.json b/maro/cli/k8s/lib/modes/aks/create_aks_cluster/template.json index 1e77a9802..b9db8e344 100644 --- a/maro/cli/k8s/lib/modes/aks/create_aks_cluster/template.json +++ b/maro/cli/k8s/lib/modes/aks/create_aks_cluster/template.json @@ -22,18 +22,6 @@ "Premium" ] }, - "adminPublicKey": { - "type": "string", - "metadata": { - "description": "Configure all linux machines with the SSH RSA public key string. Your key should include three parts, for example 'ssh-rsa AAAAB...snip...UcyupgH azureuser@linuxvm'" - } - }, - "adminUsername": { - "type": "string", - "metadata": { - "description": "User name for the Linux Virtual Machines." 
- } - }, "agentCount": { "type": "int", "metadata": { @@ -87,7 +75,7 @@ "resources": [ { "type": "Microsoft.Storage/storageAccounts/fileServices/shares", - "apiVersion": "2020-08-01-preview", + "apiVersion": "2021-04-01", "name": "[concat(parameters('storageAccountName'), '/default/', parameters('fileShareName'))]", "dependsOn": [ "[variables('stvmId')]" @@ -96,7 +84,7 @@ { "name": "[parameters('acrName')]", "type": "Microsoft.ContainerRegistry/registries", - "apiVersion": "2020-11-01-preview", + "apiVersion": "2021-09-01", "location": "[parameters('location')]", "sku": { "name": "[parameters('acrSku')]" @@ -107,7 +95,7 @@ { "name": "[parameters('clusterName')]", "type": "Microsoft.ContainerService/managedClusters", - "apiVersion": "2020-03-01", + "apiVersion": "2021-10-01", "location": "[parameters('location')]", "dependsOn": [ "[variables('vnetId')]" @@ -127,16 +115,6 @@ "type": "VirtualMachineScaleSets" } ], - "linuxProfile": { - "adminUsername": "[parameters('adminUsername')]", - "ssh": { - "publicKeys": [ - { - "keyData": "[parameters('adminPublicKey')]" - } - ] - } - }, "networkProfile": { "networkPlugin": "azure", "loadBalancerSku": "standard" @@ -148,7 +126,7 @@ }, { "type": "Microsoft.Storage/storageAccounts", - "apiVersion": "2020-08-01-preview", + "apiVersion": "2021-08-01", "name": "[parameters('storageAccountName')]", "location": "[parameters('location')]", "kind": "StorageV2", @@ -163,7 +141,7 @@ { "name": "[parameters('virtualNetworkName')]", "type": "Microsoft.Network/virtualNetworks", - "apiVersion": "2020-04-01", + "apiVersion": "2020-11-01", "location": "[parameters('location')]", "properties": { "addressSpace": { diff --git a/maro/cli/k8s/utils/k8s_manifest_generator.py b/maro/cli/k8s/utils/k8s_manifest_generator.py new file mode 100644 index 000000000..d46f03b4c --- /dev/null +++ b/maro/cli/k8s/utils/k8s_manifest_generator.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from typing import List + +from maro.cli.utils.common import format_env_vars + + +def get_job_manifest(agent_pool_name: str, component_name: str, container_spec: dict, volumes: List[dict]): + return { + "metadata": {"name": component_name}, + "spec": { + "template": { + "spec": { + "nodeSelector": {"agentpool": agent_pool_name}, + "restartPolicy": "Never", + "volumes": volumes, + "containers": [container_spec] + } + } + } + } + + +def get_azurefile_volume_spec(name: str, share_name: str, secret_name: str): + return { + "name": name, + "azureFile": { + "secretName": secret_name, + "shareName": share_name, + "readOnly": False + } + } + + +def get_container_spec(image_name: str, component_name: str, script: str, env: dict, volumes): + common_container_spec = { + "image": image_name, + "imagePullPolicy": "Always", + "volumeMounts": [{"name": vol["name"], "mountPath": f"/{vol['name']}"} for vol in volumes] + } + return { + **common_container_spec, + **{ + "name": component_name, + "command": ["python3", script], + "env": format_env_vars(env, mode="k8s") + } + } + + +def get_redis_deployment_manifest(host: str, port: int): + return { + "metadata": { + "name": host, + "labels": {"app": "redis"} + }, + "spec": { + "selector": { + "matchLabels": {"app": "redis"} + }, + "replicas": 1, + "template": { + "metadata": { + "labels": {"app": "redis"} + }, + "spec": { + "containers": [ + { + "name": "master", + "image": "redis:6", + "ports": [{"containerPort": port}] + } + ] + } + } + } + } + + +def get_redis_service_manifest(host: str, port: int): + return { + "metadata": { + "name": host, + "labels": {"app": "redis"} + }, + "spec": { + "ports": [{"port": port, "targetPort": port}], + "selector": {"app": "redis"} + } + } + + +def get_cross_namespace_service_access_manifest( + service_name: str, target_service_name: str, target_service_namespace: str, target_service_port: int +): + return { + "metadata": { + "name": service_name, + }, + "spec": { + "type": "ExternalName", + "externalName": f"{target_service_name}.{target_service_namespace}.svc.cluster.local", + "ports": [{"port": target_service_port}] + } + } diff --git a/maro/cli/k8s/utils/k8s_ops.py b/maro/cli/k8s/utils/k8s_ops.py new file mode 100644 index 000000000..8af7fd33d --- /dev/null +++ b/maro/cli/k8s/utils/k8s_ops.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
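+ # Thin wrappers around the official kubernetes Python client for the operations used by the AKS commands: namespaces, deployments, services, secrets and batch jobs.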
+ +import kubernetes +from kubernetes import client, config + + +def load_config(): + config.load_kube_config() + + +def create_namespace(namespace: str): + body = client.V1Namespace() + body.metadata = client.V1ObjectMeta(name=namespace) + try: + client.CoreV1Api().create_namespace(body) + except kubernetes.client.exceptions.ApiException: + pass + + +def create_deployment(conf: dict, namespace: str): + client.AppsV1Api().create_namespaced_deployment(namespace, conf) + + +def create_service(conf: dict, namespace: str): + client.CoreV1Api().create_namespaced_service(namespace, conf) + + +def create_job(conf: dict, namespace: str): + client.BatchV1Api().create_namespaced_job(namespace, conf) + + +def create_secret(name: str, data: dict, namespace: str): + client.CoreV1Api().create_namespaced_secret( + body=client.V1Secret(metadata=client.V1ObjectMeta(name=name), data=data), + namespace=namespace + ) + + +def delete_job(namespace: str): + client.BatchV1Api().delete_collection_namespaced_job(namespace) + client.CoreV1Api().delete_namespace(namespace) + + +def describe_job(namespace: str): + client.CoreV1Api().read_namespace(namespace) diff --git a/maro/cli/process/__init__.py b/maro/cli/local/__init__.py similarity index 100% rename from maro/cli/process/__init__.py rename to maro/cli/local/__init__.py diff --git a/maro/cli/local/commands.py b/maro/cli/local/commands.py new file mode 100644 index 000000000..26a4f7075 --- /dev/null +++ b/maro/cli/local/commands.py @@ -0,0 +1,253 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import shutil +import subprocess +import sys +import time +from os import makedirs +from os.path import abspath, dirname, exists, expanduser, join + +import redis +import yaml + +from maro.cli.utils.common import close_by_pid, show_log +from maro.rl.workflows.config import ConfigParser +from maro.utils.logger import CliLogger +from maro.utils.utils import LOCAL_MARO_ROOT + +from .utils import ( + JobStatus, RedisHashKey, start_redis, start_rl_job, start_rl_job_with_docker_compose, stop_redis, + stop_rl_job_with_docker_compose +) + +# metadata +LOCAL_ROOT = expanduser("~/.maro/local") +SESSION_STATE_PATH = join(LOCAL_ROOT, "session.json") +DOCKERFILE_PATH = join(LOCAL_MARO_ROOT, "docker_files", "dev.df") +DOCKER_IMAGE_NAME = "maro-local" +DOCKER_NETWORK = "MAROLOCAL" + +# display +NO_JOB_MANAGER_MSG = """No job manager found. Run "maro local init" to start the job manager first.""" +NO_JOB_MSG = """No job named {} found. 
Run "maro local job ls" to view existing jobs.""" +JOB_LS_TEMPLATE = "{JOB:12}{STATUS:15}{STARTED:20}" + +logger = CliLogger(name="MARO-LOCAL") + + +# helper functions +def get_redis_conn(port=None): + if port is None: + try: + with open(SESSION_STATE_PATH, "r") as fp: + port = json.load(fp)["port"] + except FileNotFoundError: + logger.error(NO_JOB_MANAGER_MSG) + return + + try: + redis_conn = redis.Redis(host="localhost", port=port) + redis_conn.ping() + return redis_conn + except redis.exceptions.ConnectionError: + logger.error(NO_JOB_MANAGER_MSG) + + +# Functions executed on CLI commands +def run(conf_path: str, containerize: bool = False, evaluate_only: bool = False, **kwargs): + # Load job configuration file + parser = ConfigParser(conf_path) + if containerize: + try: + start_rl_job_with_docker_compose( + parser, LOCAL_MARO_ROOT, DOCKERFILE_PATH, DOCKER_IMAGE_NAME, evaluate_only=evaluate_only, + ) + except KeyboardInterrupt: + stop_rl_job_with_docker_compose(parser.config["job"], LOCAL_MARO_ROOT) + else: + try: + start_rl_job(parser, LOCAL_MARO_ROOT, evaluate_only=evaluate_only) + except KeyboardInterrupt: + sys.exit(1) + + +def init( + port: int = 19999, + max_running: int = 3, + query_every: int = 5, + timeout: int = 3, + containerize: bool = False, + **kwargs +): + if exists(SESSION_STATE_PATH): + with open(SESSION_STATE_PATH, "r") as fp: + session_state = json.load(fp) + logger.warning( + f"Local job manager is already running at port {session_state['port']}. " + f"Run 'maro local job add/rm' to add / remove jobs." + ) + return + + start_redis(port) + + # Start job manager + command = ["python", join(dirname(abspath(__file__)), 'job_manager.py')] + job_manager = subprocess.Popen( + command, + env={ + "PYTHONPATH": LOCAL_MARO_ROOT, + "MAX_RUNNING": str(max_running), + "QUERY_EVERY": str(query_every), + "SIGTERM_TIMEOUT": str(timeout), + "CONTAINERIZE": str(containerize), + "REDIS_PORT": str(port), + "LOCAL_MARO_ROOT": LOCAL_MARO_ROOT, + "DOCKER_IMAGE_NAME": DOCKER_IMAGE_NAME, + "DOCKERFILE_PATH": DOCKERFILE_PATH + } + ) + + # Dump environment setting + makedirs(LOCAL_ROOT, exist_ok=True) + with open(SESSION_STATE_PATH, "w") as fp: + json.dump({"port": port, "job_manager_pid": job_manager.pid, "containerized": containerize}, fp) + + # Create log folder + logger.info("Local job manager started") + + +def exit(**kwargs): + try: + with open(SESSION_STATE_PATH, "r") as fp: + session_state = json.load(fp) + except FileNotFoundError: + logger.error(NO_JOB_MANAGER_MSG) + return + + redis_conn = get_redis_conn() + + # Mark all jobs as REMOVED and let the job manager terminate them properly. + job_details = redis_conn.hgetall(RedisHashKey.JOB_DETAILS) + if job_details: + for job_name, details in job_details.items(): + details = json.loads(details) + details["status"] = JobStatus.REMOVED + redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details)) + logger.info(f"Gracefully terminating job {job_name.decode()}") + + # Stop job manager + close_by_pid(int(session_state["job_manager_pid"])) + + # Stop Redis + stop_redis(session_state["port"]) + + # Remove dump folder. 
+ shutil.rmtree(LOCAL_ROOT, True) + + logger.info("Local job manager terminated.") + + +def add_job(conf_path: str, **kwargs): + redis_conn = get_redis_conn() + if not redis_conn: + return + + # Load job configuration file + with open(conf_path, "r") as fr: + conf = yaml.safe_load(fr) + + job_name = conf["job"] + if redis_conn.hexists(RedisHashKey.JOB_DETAILS, job_name): + logger.error(f"A job named '{job_name}' has already been added.") + return + + # Push job config to redis + redis_conn.hset(RedisHashKey.JOB_CONF, job_name, json.dumps(conf)) + details = { + "status": JobStatus.PENDING, + "added": time.time() + } + redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details)) + + +def remove_jobs(job_names, **kwargs): + redis_conn = get_redis_conn() + if not redis_conn: + return + + for job_name in job_names: + details = redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name) + if not details: + logger.error(f"No job named '{job_name}' has been scheduled or started.") + else: + details = json.loads(details) + details["status"] = JobStatus.REMOVED + redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details)) + logger.info(f"Removed job {job_name}") + + +def describe_job(job_name, **kwargs): + redis_conn = get_redis_conn() + if not redis_conn: + return + + details = redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name) + if not details: + logger.error(NO_JOB_MSG.format(job_name)) + return + + details = json.loads(details) + err = "error_message" in details + if err: + err_msg = details["error_message"].split('\n') + del details["error_message"] + + logger.info(details) + if err: + for line in err_msg: + logger.info(line) + + +def get_job_logs(job_name: str, tail: int = -1, **kwargs): + redis_conn = get_redis_conn() + if not redis_conn.hexists(RedisHashKey.JOB_CONF, job_name): + logger.error(NO_JOB_MSG.format(job_name)) + return + + conf = json.loads(redis_conn.hget(RedisHashKey.JOB_CONF, job_name)) + show_log(conf["log_path"], tail=tail) + + +def list_jobs(**kwargs): + redis_conn = get_redis_conn() + if not redis_conn: + return + + def get_time_diff_string(time_diff): + time_diff = int(time_diff) + days = time_diff // (3600 * 24) + if days: + return f"{days} days" + + hours = time_diff // 3600 + if hours: + return f"{hours} hours" + + minutes = time_diff // 60 + if minutes: + return f"{minutes} minutes" + + return f"{time_diff} seconds" + + # Header + logger.info(JOB_LS_TEMPLATE.format(JOB="JOB", STATUS="STATUS", STARTED="STARTED")) + for job_name, details in redis_conn.hgetall(RedisHashKey.JOB_DETAILS).items(): + job_name = job_name.decode() + details = json.loads(details) + if "start_time" in details: + time_diff = f"{get_time_diff_string(time.time() - details['start_time'])} ago" + logger.info(JOB_LS_TEMPLATE.format(JOB=job_name, STATUS=details["status"], STARTED=time_diff)) + else: + logger.info(JOB_LS_TEMPLATE.format(JOB=job_name, STATUS=details["status"], STARTED=JobStatus.PENDING)) diff --git a/maro/cli/local/job_manager.py b/maro/cli/local/job_manager.py new file mode 100644 index 000000000..d36ec7662 --- /dev/null +++ b/maro/cli/local/job_manager.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
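+ # Daemon launched by "maro local init": it polls Redis for pending job configurations, starts at most MAX_RUNNING jobs at a time (as local processes or containers), and monitors each running job in its own thread so that errors, completion and removal requests are written back to Redis.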
+ +import json +import os +import threading +import time + +import redis + +from maro.cli.local.utils import JobStatus, RedisHashKey, poll, start_rl_job, start_rl_job_in_containers, term +from maro.cli.utils.docker import build_image, image_exists +from maro.rl.workflows.config import ConfigParser + +if __name__ == "__main__": + redis_port = int(os.getenv("REDIS_PORT", default=19999)) + redis_conn = redis.Redis(host="localhost", port=redis_port) + started, max_running = {}, int(os.getenv("MAX_RUNNING", default=1)) + query_every = int(os.getenv("QUERY_EVERY", default=5)) + sigterm_timeout = int(os.getenv("SIGTERM_TIMEOUT", default=3)) + containerize = os.getenv("CONTAINERIZE", default="False") == "True" + local_maro_root = os.getenv("LOCAL_MARO_ROOT") + docker_file_path = os.getenv("DOCKERFILE_PATH") + docker_image_name = os.getenv("DOCKER_IMAGE_NAME") + + # thread to monitor a job + def monitor(job_name): + removed, error, err_out, running = False, False, None, started[job_name] + while running: + error, err_out, running = poll(running) + # check if the job has been marked as REMOVED before termination + details = json.loads(redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name)) + if details["status"] == JobStatus.REMOVED: + removed = True + break + + if error: + break + + if removed: + term(started[job_name], job_name, timeout=sigterm_timeout) + redis_conn.hdel(RedisHashKey.JOB_DETAILS, job_name) + redis_conn.hdel(RedisHashKey.JOB_CONF, job_name) + return + + if error: + term(started[job_name], job_name, timeout=sigterm_timeout) + details["status"] = JobStatus.ERROR + details["error_message"] = err_out + redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details)) + else: # all job processes terminated normally + details["status"] = JobStatus.FINISHED + redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details)) + + # Continue to monitor if the job is marked as REMOVED + while json.loads(redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name))["status"] != JobStatus.REMOVED: + time.sleep(query_every) + + term(started[job_name], job_name, timeout=sigterm_timeout) + redis_conn.hdel(RedisHashKey.JOB_DETAILS, job_name) + redis_conn.hdel(RedisHashKey.JOB_CONF, job_name) + + while True: + # check for pending jobs + job_details = redis_conn.hgetall(RedisHashKey.JOB_DETAILS) + if job_details: + num_running, pending = 0, [] + for job_name, details in job_details.items(): + job_name, details = job_name.decode(), json.loads(details) + if details["status"] == JobStatus.RUNNING: + num_running += 1 + elif details["status"] == JobStatus.PENDING: + pending.append((job_name, json.loads(redis_conn.hget(RedisHashKey.JOB_CONF, job_name)))) + + for job_name, conf in pending[:max(0, max_running - num_running)]: + if containerize and not image_exists(docker_image_name): + redis_conn.hset( + RedisHashKey.JOB_DETAILS, job_name, json.dumps({"status": JobStatus.IMAGE_BUILDING}) + ) + build_image(local_maro_root, docker_file_path, docker_image_name) + + parser = ConfigParser(conf) + if containerize: + path_mapping = parser.get_path_mapping(containerize=True) + started[job_name] = start_rl_job_in_containers(parser, docker_image_name) + details["containers"] = started[job_name] + else: + started[job_name] = start_rl_job(parser, local_maro_root, background=True) + details["pids"] = [proc.pid for proc in started[job_name]] + details = {"status": JobStatus.RUNNING, "start_time": time.time()} + redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details)) + threading.Thread(target=monitor, 
args=(job_name,)).start() # start job monitoring thread + + time.sleep(query_every) diff --git a/maro/cli/local/utils.py b/maro/cli/local/utils.py new file mode 100644 index 000000000..643f780c2 --- /dev/null +++ b/maro/cli/local/utils.py @@ -0,0 +1,195 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import subprocess +from copy import deepcopy +from typing import List + +import docker +import yaml + +from maro.cli.utils.common import format_env_vars +from maro.rl.workflows.config.parser import ConfigParser + + +class RedisHashKey: + """Record Redis elements name, and only for maro process""" + JOB_CONF = "job_conf" + JOB_DETAILS = "job_details" + + +class JobStatus: + PENDING = "pending" + IMAGE_BUILDING = "image_building" + RUNNING = "running" + ERROR = "error" + REMOVED = "removed" + FINISHED = "finished" + + +def start_redis(port: int): + subprocess.Popen(["redis-server", "--port", str(port)], stdout=subprocess.DEVNULL) + + +def stop_redis(port: int): + subprocess.Popen(["redis-cli", "-p", str(port), "shutdown"], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE) + + +def extract_error_msg_from_docker_log(container: docker.models.containers.Container): + logs = container.logs().decode().splitlines() + for i, log in enumerate(logs): + if "Traceback (most recent call last):" in log: + return "\n".join(logs[i:]) + + return logs + + +def check_proc_status(proc): + if isinstance(proc, subprocess.Popen): + if proc.poll() is None: + return True, 0, None + _, err_out = proc.communicate() + return False, proc.returncode, err_out + else: + client = docker.from_env() + container_state = client.api.inspect_container(proc.id)["State"] + return container_state["Running"], container_state["ExitCode"], extract_error_msg_from_docker_log(proc) + + +def poll(procs): + error, running = False, [] + for proc in procs: + is_running, exit_code, err_out = check_proc_status(proc) + if is_running: + running.append(proc) + elif exit_code: + error = True + break + + return error, err_out, running + + +def term(procs, job_name: str, timeout: int = 3): + if isinstance(procs[0], subprocess.Popen): + for proc in procs: + if proc.poll() is None: + try: + proc.terminate() + proc.wait(timeout=timeout) + except subprocess.TimeoutExpired: + proc.kill() + else: + for proc in procs: + try: + proc.stop(timeout=timeout) + proc.remove() + except Exception: + pass + + client = docker.from_env() + try: + job_network = client.networks.get(job_name) + job_network.remove() + except Exception: + pass + + +def exec(cmd: str, env: dict, debug: bool = False) -> subprocess.Popen: + stream = None if debug else subprocess.PIPE + return subprocess.Popen( + cmd.split(), env={**os.environ.copy(), **env}, stdout=stream, stderr=stream, encoding="utf8" + ) + + +def start_rl_job( + parser: ConfigParser, maro_root: str, evaluate_only: bool, background: bool = False, +) -> List[subprocess.Popen]: + procs = [ + exec( + f"python {script}" + ("" if not evaluate_only else " --evaluate_only"), + format_env_vars({**env, "PYTHONPATH": maro_root}, mode="proc"), + debug=not background + ) + for script, env in parser.get_job_spec().values() + ] + if not background: + for proc in procs: + proc.communicate() + + return procs + + +def start_rl_job_in_containers(parser: ConfigParser, image_name: str) -> list: + job_name = parser.config["job"] + client, containers = docker.from_env(), [] + training_mode = parser.config["training"]["mode"] + if "parallelism" in parser.config["rollout"]: + rollout_parallelism = max( + 
parser.config["rollout"]["parallelism"]["sampling"], + parser.config["rollout"]["parallelism"].get("eval", 1) + ) + else: + rollout_parallelism = 1 + if training_mode != "simple" or rollout_parallelism > 1: + # create the exclusive network for the job + client.networks.create(job_name, driver="bridge") + + for component, (script, env) in parser.get_job_spec(containerize=True).items(): + # volume mounts for scenario folder, policy loading, checkpointing and logging + container = client.containers.run( + image_name, + command=f"python3 {script}", + detach=True, + name=component, + environment=env, + volumes=[f"{src}:{dst}" for src, dst in parser.get_path_mapping(containerize=True).items()], + network=job_name + ) + + containers.append(container) + + return containers + + +def get_docker_compose_yml_path(maro_root: str) -> str: + return os.path.join(maro_root, ".tmp", "docker-compose.yml") + + +def start_rl_job_with_docker_compose( + parser: ConfigParser, context: str, dockerfile_path: str, image_name: str, evaluate_only: bool, +) -> None: + common_spec = { + "build": {"context": context, "dockerfile": dockerfile_path}, + "image": image_name, + "volumes": [f"./{src}:{dst}" for src, dst in parser.get_path_mapping(containerize=True).items()] + } + + job_name = parser.config["job"] + manifest = { + "version": "3.9", + "services": { + component: { + **deepcopy(common_spec), + **{ + "container_name": component, + "command": f"python3 {script}" + ("" if not evaluate_only else " --evaluate_only"), + "environment": format_env_vars(env, mode="docker-compose") + } + } + for component, (script, env) in parser.get_job_spec(containerize=True).items() + }, + } + + docker_compose_file_path = get_docker_compose_yml_path(maro_root=context) + with open(docker_compose_file_path, "w") as fp: + yaml.safe_dump(manifest, fp) + + subprocess.run( + ["docker-compose", "--project-name", job_name, "-f", docker_compose_file_path, "up", "--remove-orphans"] + ) + + +def stop_rl_job_with_docker_compose(job_name: str, context: str): + subprocess.run(["docker-compose", "--project-name", job_name, "down"]) + os.remove(get_docker_compose_yml_path(maro_root=context)) diff --git a/maro/cli/maro.py b/maro/cli/maro.py index d4b8aabd6..b77a807f1 100644 --- a/maro/cli/maro.py +++ b/maro/cli/maro.py @@ -90,6 +90,15 @@ def main(): parser_k8s.set_defaults(func=_help_func(parser=parser_k8s)) load_parser_k8s(prev_parser=parser_k8s, global_parser=global_parser) + # maro aks + parser_aks = subparsers.add_parser( + "aks", + help="Manage distributed cluster with Kubernetes.", + parents=[global_parser] + ) + parser_aks.set_defaults(func=_help_func(parser=parser_aks)) + load_parser_aks(prev_parser=parser_aks, global_parser=global_parser) + # maro inspector parser_inspector = subparsers.add_parser( 'inspector', @@ -99,13 +108,13 @@ def main(): parser_inspector.set_defaults(func=_help_func(parser=parser_inspector)) load_parser_inspector(parser_inspector, global_parser) - # maro process - parser_process = subparsers.add_parser( - "process", - help="Run application by mulit-process to simulate distributed mode." + # maro local + parser_local = subparsers.add_parser( + "local", + help="Run jobs locally." 
) - parser_process.set_defaults(func=_help_func(parser=parser_process)) - load_parser_process(prev_parser=parser_process, global_parser=global_parser) + parser_local.set_defaults(func=_help_func(parser=parser_local)) + load_parser_local(prev_parser=parser_local, global_parser=global_parser) # maro project parser_project = subparsers.add_parser( @@ -151,152 +160,128 @@ def main(): logger.error_red(f"{e.__class__.__name__}: {e.get_message()}") -def load_parser_process(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None: +def load_parser_local(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None: subparsers = prev_parser.add_subparsers() - # maro process create - from maro.cli.process.create import create - parser_setup = subparsers.add_parser( - "create", - help="Create local process environment.", + # maro local run + from maro.cli.local.commands import run + parser = subparsers.add_parser( + "run", + help="Run a job in debug mode.", examples=CliExamples.MARO_PROCESS_SETUP, parents=[global_parser] ) - parser_setup.add_argument( - 'deployment_path', - help='Path of the local process setting deployment.', - nargs='?', - default=None) - parser_setup.set_defaults(func=create) + parser.add_argument("conf_path", help='Path of the job deployment') + parser.add_argument("-c", "--containerize", action="store_true", help="Whether to run jobs in containers") + parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow") + parser.add_argument("-p", "--port", type=int, default=20000, help="") + parser.set_defaults(func=run) - # maro process delete - from maro.cli.process.delete import delete - parser_setup = subparsers.add_parser( - "delete", - help="Delete the local process environment. Including closing agents and maro Redis.", + # maro local init + from maro.cli.local.commands import init + parser = subparsers.add_parser( + "init", + help="Initialize local job manager.", + examples=CliExamples.MARO_PROCESS_SETUP, + parents=[global_parser] + ) + parser.add_argument( + "-p", "--port", type=int, default=19999, + help="Port on local machine to launch the Redis server at. Defaults to 19999." + ) + parser.add_argument( + "-m", "--max-running", type=int, default=3, + help="Maximum number of jobs to allow running at the same time. Defaults to 3." + ) + parser.add_argument( + "-q", "--query-every", type=int, default=5, + help="Number of seconds to wait between queries to the Redis server for pending or removed jobs. Defaults to 5." + ) + parser.add_argument( + "-t", "--timeout", type=int, default=3, + help=""" + Number of seconds to wait after sending SIGTERM to a process. If the process does not terminate + during this time, the process will be force-killed through SIGKILL. Defaults to 3. 
+ """ + ) + parser.add_argument("-c", "--containerize", action="store_true", help="Whether to run jobs in containers") + parser.set_defaults(func=init) + + # maro local exit + from maro.cli.local.commands import exit + parser = subparsers.add_parser( + "exit", + help="Terminate the local job manager", parents=[global_parser] ) - parser_setup.set_defaults(func=delete) + parser.set_defaults(func=exit) - # maro process job - parser_job = subparsers.add_parser( + # maro local job + parser = subparsers.add_parser( "job", help="Manage jobs", parents=[global_parser] ) - parser_job.set_defaults(func=_help_func(parser=parser_job)) - parser_job_subparsers = parser_job.add_subparsers() + parser.set_defaults(func=_help_func(parser=parser)) + job_subparsers = parser.add_subparsers() - # maro process job start - from maro.cli.process.job import start_job - parser_job_start = parser_job_subparsers.add_parser( - 'start', - help='Start a training job', + # maro local job add + from maro.cli.local.commands import add_job + job_add_parser = job_subparsers.add_parser( + "add", + help="Start an RL job", examples=CliExamples.MARO_PROCESS_JOB_START, parents=[global_parser] ) - parser_job_start.add_argument( - 'deployment_path', help='Path of the job deployment') - parser_job_start.set_defaults(func=start_job) + job_add_parser.add_argument("conf_path", help='Path of the job deployment') + job_add_parser.set_defaults(func=add_job) - # maro process job stop - from maro.cli.process.job import stop_job - parser_job_stop = parser_job_subparsers.add_parser( - 'stop', - help='Stop a training job', + # maro local job rm + from maro.cli.local.commands import remove_jobs + job_stop_parser = job_subparsers.add_parser( + "rm", + help='Stop an RL job', examples=CliExamples.MARO_PROCESS_JOB_STOP, parents=[global_parser] ) - parser_job_stop.add_argument( - 'job_name', help='Name of the job') - parser_job_stop.set_defaults(func=stop_job) + job_stop_parser.add_argument('job_names', help="Job names", nargs="*") + job_stop_parser.set_defaults(func=remove_jobs) - # maro process job delete - from maro.cli.process.job import delete_job - parser_job_delete = parser_job_subparsers.add_parser( - 'delete', - help='delete a stopped job', - examples=CliExamples.MARO_PROCESS_JOB_DELETE, + # maro local job describe + from maro.cli.local.commands import describe_job + job_stop_parser = job_subparsers.add_parser( + "describe", + help="Get the status of an RL job and the error information if the job fails due to some error", + examples=CliExamples.MARO_PROCESS_JOB_STOP, parents=[global_parser] ) - parser_job_delete.add_argument( - 'job_name', help='Name of the job or the schedule') - parser_job_delete.set_defaults(func=delete_job) + job_stop_parser.add_argument('job_name', help='Job name') + job_stop_parser.set_defaults(func=describe_job) - # maro process job list - from maro.cli.process.job import list_jobs - parser_job_list = parser_job_subparsers.add_parser( - 'list', + # maro local job ls + from maro.cli.local.commands import list_jobs + job_list_parser = job_subparsers.add_parser( + "ls", help='List all jobs', examples=CliExamples.MARO_PROCESS_JOB_LIST, parents=[global_parser] ) - parser_job_list.set_defaults(func=list_jobs) + job_list_parser.set_defaults(func=list_jobs) - # maro process job logs - from maro.cli.process.job import get_job_logs - parser_job_logs = parser_job_subparsers.add_parser( - 'logs', - help='Get logs of the job', + # maro local job logs + from maro.cli.local.commands import get_job_logs + job_logs_parser = 
job_subparsers.add_parser( + "logs", + help="Get job logs", examples=CliExamples.MARO_PROCESS_JOB_LOGS, parents=[global_parser] ) - parser_job_logs.add_argument( - 'job_name', help='Name of the job') - parser_job_logs.set_defaults(func=get_job_logs) - - # maro process schedule - parser_schedule = subparsers.add_parser( - 'schedule', - help='Manage schedules', - parents=[global_parser] - ) - parser_schedule.set_defaults(func=_help_func(parser=parser_schedule)) - parser_schedule_subparsers = parser_schedule.add_subparsers() - - # maro process schedule start - from maro.cli.process.schedule import start_schedule - parser_schedule_start = parser_schedule_subparsers.add_parser( - 'start', - help='Start a schedule', - examples=CliExamples.MARO_PROCESS_SCHEDULE_START, - parents=[global_parser] + job_logs_parser.add_argument("job_name", help="job name") + job_logs_parser.add_argument( + "-n", "--tail", type=int, default=-1, + help="Number of lines to show from the end of the given job's logs" ) - parser_schedule_start.add_argument( - 'deployment_path', help='Path of the schedule deployment') - parser_schedule_start.set_defaults(func=start_schedule) - - # maro process schedule stop - from maro.cli.process.schedule import stop_schedule - parser_schedule_stop = parser_schedule_subparsers.add_parser( - 'stop', - help='Stop a schedule', - examples=CliExamples.MARO_PROCESS_SCHEDULE_STOP, - parents=[global_parser] - ) - parser_schedule_stop.add_argument( - 'schedule_name', help='Name of the schedule') - parser_schedule_stop.set_defaults(func=stop_schedule) - - # maro process template - from maro.cli.process.template import template - parser_template = subparsers.add_parser( - "template", - help="Get deployment templates", - examples=CliExamples.MARO_PROCESS_TEMPLATE, - parents=[global_parser] - ) - parser_template.add_argument( - "--setting_deploy", - action="store_true", - help="Get environment setting templates" - ) - parser_template.add_argument( - "export_path", - default="./", - nargs='?', - help="Path of the export directory") - parser_template.set_defaults(func=template) + job_logs_parser.set_defaults(func=get_job_logs) def load_parser_grass(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None: @@ -922,6 +907,81 @@ def load_parser_k8s(prev_parser: ArgumentParser, global_parser: ArgumentParser) parser_template.set_defaults(func=template) +def load_parser_aks(prev_parser: ArgumentParser, global_parser: ArgumentParser) -> None: + subparsers = prev_parser.add_subparsers() + + # maro aks create + from maro.cli.k8s.aks.aks_commands import init + parser_create = subparsers.add_parser( + "init", + help=""" + Deploy resources and start required services on Azure. The configuration file template can be found + in cli/k8s/aks/conf.yml. Use the Azure CLI to log into your Azure account (az login ...) and the the + Azure Container Registry (az acr login ...) first. 
+ """, + examples=CliExamples.MARO_K8S_CREATE, + parents=[global_parser] + ) + parser_create.add_argument("deployment_conf_path", help="Path of the deployment configuration file") + parser_create.set_defaults(func=init) + + # maro aks exit + from maro.cli.k8s.aks.aks_commands import exit + parser_create = subparsers.add_parser( + "exit", + help="Delete deployed resources", + examples=CliExamples.MARO_K8S_DELETE, + parents=[global_parser] + ) + parser_create.set_defaults(func=exit) + + # maro aks job + parser_job = subparsers.add_parser( + "job", + help="AKS job-related commands", + parents=[global_parser] + ) + parser_job.set_defaults(func=_help_func(parser=parser_job)) + job_subparsers = parser_job.add_subparsers() + + # maro aks job add + from maro.cli.k8s.aks.aks_commands import add_job + parser_job_start = job_subparsers.add_parser( + "add", + help="Add an RL job to the AKS cluster", + examples=CliExamples.MARO_K8S_JOB_START, + parents=[global_parser] + ) + parser_job_start.add_argument("conf_path", help="Path to the job configuration file") + parser_job_start.set_defaults(func=add_job) + + # maro aks job rm + from maro.cli.k8s.aks.aks_commands import remove_jobs + parser_job_start = job_subparsers.add_parser( + "rm", + help="Remove previously scheduled RL jobs from the AKS cluster", + examples=CliExamples.MARO_K8S_JOB_START, + parents=[global_parser] + ) + parser_job_start.add_argument("job_names", help="Name of job to be removed", nargs="*") + parser_job_start.set_defaults(func=remove_jobs) + + # maro aks job logs + from maro.cli.k8s.aks.aks_commands import get_job_logs + job_logs_parser = job_subparsers.add_parser( + "logs", + help="Get job logs", + examples=CliExamples.MARO_PROCESS_JOB_LOGS, + parents=[global_parser] + ) + job_logs_parser.add_argument("job_name", help="job name") + job_logs_parser.add_argument( + "-n", "--tail", type=int, default=-1, + help="Number of lines to show from the end of the given job's logs" + ) + job_logs_parser.set_defaults(func=get_job_logs) + + def load_parser_data(prev_parser: ArgumentParser, global_parser: ArgumentParser): data_cmd_sub_parsers = prev_parser.add_subparsers() diff --git a/maro/cli/process/agent/job_agent.py b/maro/cli/process/agent/job_agent.py deleted file mode 100644 index 57414e279..000000000 --- a/maro/cli/process/agent/job_agent.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
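The agent deleted below drove jobs through pending/killed ticket lists in Redis. The new local job manager (top of this section) instead polls the `job_conf` and `job_details` hashes and launches anything whose status is `pending`. A minimal sketch of enqueueing a job under those conventions; the actual `add_job` implementation in maro.cli.local.commands is not part of this hunk, so the snippet is illustrative:

```python
# Illustrative only: enqueue a job the way the new local job manager expects.
# The real logic lives in maro.cli.local.commands (not shown in this diff).
import json

import redis
import yaml

from maro.cli.local.utils import JobStatus, RedisHashKey


def add_job_sketch(conf_path: str, redis_port: int = 19999) -> None:
    with open(conf_path, "r") as fp:
        conf = yaml.safe_load(fp)

    job_name = conf["job"]  # both hashes are keyed by job name
    redis_conn = redis.Redis(host="localhost", port=redis_port)
    # Store the raw configuration; the job manager feeds it to ConfigParser on launch.
    redis_conn.hset(RedisHashKey.JOB_CONF, job_name, json.dumps(conf))
    # Mark the job as pending; the manager flips it to "running" once processes start.
    redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps({"status": JobStatus.PENDING}))
```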
- -import json -import multiprocessing as mp -import os -import subprocess -import time - -import psutil -import redis - -from maro.cli.grass.lib.services.utils.params import JobStatus -from maro.cli.process.utils.details import close_by_pid, get_child_pid -from maro.cli.utils.details_reader import DetailsReader -from maro.cli.utils.params import LocalPaths, ProcessRedisName - - -class PendingJobAgent(mp.Process): - def __init__(self, cluster_detail: dict, redis_connection, check_interval: int = 60): - super().__init__() - self.cluster_detail = cluster_detail - self.redis_connection = redis_connection - self.check_interval = check_interval - - def run(self): - while True: - self._check_pending_ticket() - time.sleep(self.check_interval) - - def _check_pending_ticket(self): - # Check pending job ticket - pending_jobs = self.redis_connection.lrange(ProcessRedisName.PENDING_JOB_TICKETS, 0, -1) - running_jobs_length = len(JobTrackingAgent.get_running_jobs( - self.redis_connection.hgetall(ProcessRedisName.JOB_DETAILS) - )) - parallel_level = self.cluster_detail["parallel_level"] - - for job_name in pending_jobs: - job_detail = json.loads(self.redis_connection.hget(ProcessRedisName.JOB_DETAILS, job_name)) - # Start pending job only if current running job's number less than parallel level. - if int(parallel_level) > running_jobs_length: - self._start_job(job_detail) - self.redis_connection.lrem(ProcessRedisName.PENDING_JOB_TICKETS, 0, job_name) - running_jobs_length += 1 - - def _start_job(self, job_details: dict): - command_pid_list = [] - for component_type, command_info in job_details["components"].items(): - component_number = command_info["num"] - component_command = f"JOB_NAME={job_details['name']} " + command_info["command"] - for number in range(component_number): - job_local_path = os.path.expanduser(f"{LocalPaths.MARO_PROCESS}/{job_details['name']}") - if not os.path.exists(job_local_path): - os.makedirs(job_local_path) - - with open(f"{job_local_path}/{component_type}_{number}.log", "w") as log_file: - proc = subprocess.Popen(component_command, shell=True, stdout=log_file) - command_pid = get_child_pid(proc.pid) - if not command_pid: - command_pid_list.append(proc.pid) - else: - command_pid_list.append(command_pid) - - job_details["status"] = JobStatus.RUNNING - job_details["pid_list"] = command_pid_list - self.redis_connection.hset(ProcessRedisName.JOB_DETAILS, job_details["name"], json.dumps(job_details)) - - -class JobTrackingAgent(mp.Process): - def __init__(self, cluster_detail: dict, redis_connection, check_interval: int = 60): - super().__init__() - self.cluster_detail = cluster_detail - self.redis_connection = redis_connection - self.check_interval = check_interval - self._shutdown_count = 0 - self._countdown = cluster_detail["agent_countdown"] - - def run(self): - while True: - self._check_job_status() - time.sleep(self.check_interval) - keep_alive = self.cluster_detail["keep_agent_alive"] - if not keep_alive: - self._close_agents() - - def _check_job_status(self): - running_jobs = self.get_running_jobs(self.redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)) - - for running_job_name, running_job_detail in running_jobs.items(): - # Check pid status - still_alive = False - for pid in running_job_detail["pid_list"]: - if psutil.pid_exists(pid): - still_alive = True - - # Update if no pid exists - if not still_alive: - running_job_detail["status"] = JobStatus.FINISH - del running_job_detail["pid_list"] - self.redis_connection.hset( - ProcessRedisName.JOB_DETAILS, - 
running_job_name, - json.dumps(running_job_detail) - ) - - @staticmethod - def get_running_jobs(job_details: dict): - running_jobs = {} - - for job_name, job_detail in job_details.items(): - job_detail = json.loads(job_detail) - if job_detail["status"] == JobStatus.RUNNING: - running_jobs[job_name.decode()] = job_detail - - return running_jobs - - def _close_agents(self): - if ( - not len( - JobTrackingAgent.get_running_jobs(self.redis_connection.hgetall(ProcessRedisName.JOB_DETAILS)) - ) and - not self.redis_connection.llen(ProcessRedisName.PENDING_JOB_TICKETS) - ): - self._shutdown_count += 1 - else: - self._shutdown_count = 0 - - if self._shutdown_count >= self._countdown: - agent_pid = int(self.redis_connection.hget(ProcessRedisName.SETTING, "agent_pid")) - - # close agent - close_by_pid(pid=agent_pid, recursive=True) - - # Set agent status to 0 - self.redis_connection.hset(ProcessRedisName.SETTING, "agent_status", 0) - - -class KilledJobAgent(mp.Process): - def __init__(self, cluster_detail: dict, redis_connection, check_interval: int = 60): - super().__init__() - self.cluster_detail = cluster_detail - self.redis_connection = redis_connection - self.check_interval = check_interval - - def run(self): - while True: - self._check_killed_tickets() - time.sleep(self.check_interval) - - def _check_killed_tickets(self): - # Check pending job ticket - killed_job_names = self.redis_connection.lrange(ProcessRedisName.KILLED_JOB_TICKETS, 0, -1) - - for job_name in killed_job_names: - job_detail = json.loads(self.redis_connection.hget(ProcessRedisName.JOB_DETAILS, job_name)) - if job_detail["status"] == JobStatus.RUNNING: - close_by_pid(pid=job_detail["pid_list"], recursive=False) - del job_detail["pid_list"] - elif job_detail["status"] == JobStatus.PENDING: - self.redis_connection.lrem(ProcessRedisName.PENDING_JOB_TICKETS, 0, job_name) - elif job_detail["status"] == JobStatus.FINISH: - continue - - job_detail["status"] = JobStatus.KILLED - self.redis_connection.hset(ProcessRedisName.JOB_DETAILS, job_name, json.dumps(job_detail)) - self.redis_connection.lrem(ProcessRedisName.KILLED_JOB_TICKETS, 0, job_name) - - -class MasterAgent: - def __init__(self): - self.cluster_detail = DetailsReader.load_cluster_details("process") - self.check_interval = self.cluster_detail["check_interval"] - self.redis_connection = redis.Redis( - host=self.cluster_detail["redis_info"]["host"], - port=self.cluster_detail["redis_info"]["port"] - ) - self.redis_connection.hset(ProcessRedisName.SETTING, "agent_pid", os.getpid()) - - def start(self) -> None: - """Start agents.""" - pending_job_agent = PendingJobAgent( - cluster_detail=self.cluster_detail, - redis_connection=self.redis_connection, - check_interval=self.check_interval - ) - pending_job_agent.start() - - killed_job_agent = KilledJobAgent( - cluster_detail=self.cluster_detail, - redis_connection=self.redis_connection, - check_interval=self.check_interval - ) - killed_job_agent.start() - - job_tracking_agent = JobTrackingAgent( - cluster_detail=self.cluster_detail, - redis_connection=self.redis_connection, - check_interval=self.check_interval - ) - job_tracking_agent.start() - - -if __name__ == "__main__": - master_agent = MasterAgent() - master_agent.start() diff --git a/maro/cli/process/agent/resource_agent.py b/maro/cli/process/agent/resource_agent.py deleted file mode 100644 index 66c5d201a..000000000 --- a/maro/cli/process/agent/resource_agent.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
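The JobTrackingAgent removed above checked liveness with psutil PID lookups. The new helpers in maro/cli/local/utils.py instead treat a job as a list of subprocess.Popen handles (or Docker containers) and report status through check_proc_status() and poll(). A small self-contained demo of the poll() contract, assuming a POSIX environment where the `sleep` command is available:

```python
# Minimal demo of poll() from maro/cli/local/utils.py: it returns
# (error, err_out, still_running) for a list of Popen handles.
import subprocess
import time

from maro.cli.local.utils import poll

procs = [subprocess.Popen(["sleep", "2"]), subprocess.Popen(["sleep", "2"])]
while True:
    error, err_out, running = poll(procs)
    if error:
        print(f"a process failed: {err_out}")
        break
    if not running:
        print("all processes exited normally")
        break
    procs = running  # keep polling only the processes that are still alive
    time.sleep(1)
```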
- -import json -import multiprocessing as mp -import os -import time - -import redis - -from maro.cli.utils.params import LocalParams -from maro.cli.utils.resource_executor import ResourceInfo -from maro.utils.exception.cli_exception import BadRequestError - - -class ResourceTrackingAgent(mp.Process): - def __init__( - self, - check_interval: int = 30 - ): - super().__init__() - self._redis_connection = redis.Redis(host="localhost", port=LocalParams.RESOURCE_REDIS_PORT) - try: - if self._redis_connection.hexists(LocalParams.RESOURCE_INFO, "check_interval"): - self._check_interval = int(self._redis_connection.hget(LocalParams.RESOURCE_INFO, "check_interval")) - else: - self._check_interval = check_interval - except Exception: - raise BadRequestError( - "Failure to connect to Resource Redis." - "Please make sure at least one cluster running." - ) - - self._set_resource_info() - - def _set_resource_info(self): - # Set resource agent pid. - self._redis_connection.hset( - LocalParams.RESOURCE_INFO, - "agent_pid", - os.getpid() - ) - - # Set resource agent check interval. - self._redis_connection.hset( - LocalParams.RESOURCE_INFO, - "check_interval", - json.dumps(self._check_interval) - ) - - # Push static resource information into Redis. - resource = ResourceInfo.get_static_info() - self._redis_connection.hset( - LocalParams.RESOURCE_INFO, - "resource", - json.dumps(resource) - ) - - def run(self) -> None: - """Start tracking node status and updating details. - - Returns: - None. - """ - while True: - start_time = time.time() - self.push_local_resource_usage() - time.sleep(max(self._check_interval - (time.time() - start_time), 0)) - - self._check_interval = int(self._redis_connection.hget(LocalParams.RESOURCE_INFO, "check_interval")) - - def push_local_resource_usage(self): - resource_usage = ResourceInfo.get_dynamic_info(self._check_interval) - - self._redis_connection.rpush( - LocalParams.CPU_USAGE, - json.dumps(resource_usage["cpu_usage_per_core"]) - ) - - self._redis_connection.rpush( - LocalParams.MEMORY_USAGE, - json.dumps(resource_usage["memory_usage"]) - ) - - self._redis_connection.rpush( - LocalParams.GPU_USAGE, - json.dumps(resource_usage["gpu_memory_usage"]) - ) - - -if __name__ == "__main__": - resource_agent = ResourceTrackingAgent() - resource_agent.start() diff --git a/maro/cli/process/create.py b/maro/cli/process/create.py deleted file mode 100644 index 467b2db20..000000000 --- a/maro/cli/process/create.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import yaml - -from maro.cli.process.executor import ProcessExecutor -from maro.cli.process.utils.default_param import process_setting - - -def create(deployment_path: str, **kwargs): - if deployment_path is not None: - with open(deployment_path, "r") as fr: - create_deployment = yaml.safe_load(fr) - else: - create_deployment = process_setting - - executor = ProcessExecutor(create_deployment) - executor.create() diff --git a/maro/cli/process/delete.py b/maro/cli/process/delete.py deleted file mode 100644 index 8ede6a3ae..000000000 --- a/maro/cli/process/delete.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
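The deleted `create` command above provisioned the old process mode from a deployment YAML. Its replacement is configured entirely through environment variables read by the local job manager shown at the top of this section; `maro local init` sets these up, so launching the manager by hand is rarely needed. A rough sketch, with the script location and MARO root treated as assumptions:

```python
# Launch the local job manager by hand (normally `maro local init` does this).
# JOB_MANAGER_SCRIPT and LOCAL_MARO_ROOT are assumptions for illustration only.
import os
import subprocess

JOB_MANAGER_SCRIPT = "maro/cli/local/job_manager.py"  # assumed location of the manager script
env = {
    **os.environ,
    "REDIS_PORT": "19999",      # Redis instance used for job bookkeeping
    "MAX_RUNNING": "3",         # maximum number of jobs running at once
    "QUERY_EVERY": "5",         # seconds between Redis polls
    "SIGTERM_TIMEOUT": "3",     # grace period before SIGKILL
    "CONTAINERIZE": "False",    # run job components as plain processes
    "LOCAL_MARO_ROOT": os.path.expanduser("~/maro"),  # assumed checkout location
}
subprocess.Popen(["python", JOB_MANAGER_SCRIPT], env=env)
```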
- -from maro.cli.process.executor import ProcessExecutor - - -def delete(**kwargs): - executor = ProcessExecutor() - executor.delete() diff --git a/maro/cli/process/deployment/process_job_deployment.yml b/maro/cli/process/deployment/process_job_deployment.yml deleted file mode 100644 index 7720e9b40..000000000 --- a/maro/cli/process/deployment/process_job_deployment.yml +++ /dev/null @@ -1,10 +0,0 @@ -mode: process -name: MyJobName # str: name of the training job - -components: # component config - actor: - num: 5 # int: number of this component - command: "python /target/path/run_actor.py" # str: command to be executed - learner: - num: 1 - command: "python /target/path/run_learner.py" diff --git a/maro/cli/process/deployment/process_schedule_deployment.yml b/maro/cli/process/deployment/process_schedule_deployment.yml deleted file mode 100644 index d1d7769c2..000000000 --- a/maro/cli/process/deployment/process_schedule_deployment.yml +++ /dev/null @@ -1,16 +0,0 @@ -mode: process -name: MyScheduleName # str: name of the training schedule - -job_names: # list: names of the training job - - MyJobName2 - - MyJobName3 - - MyJobName4 - - MyJobName5 - -components: # component config - actor: - num: 5 # int: number of this component - command: "python /target/path/run_actor.py" # str: command to be executed - learner: - num: 1 - command: "python /target/path/run_learner.py" diff --git a/maro/cli/process/deployment/process_setting_deployment.yml b/maro/cli/process/deployment/process_setting_deployment.yml deleted file mode 100644 index 3f9168350..000000000 --- a/maro/cli/process/deployment/process_setting_deployment.yml +++ /dev/null @@ -1,8 +0,0 @@ -redis_info: - host: "localhost" - port: 19999 -redis_mode: MARO # one of MARO, customized. customized Redis won't be exited after maro process clear. -parallel_level: 1 # Represented the maximum number of running jobs in the same times. -keep_agent_alive: True # If True represented the agents won't exit until the environment delete; otherwise, False. -agent_countdown: 5 # After agent_countdown times checks, still no jobs will close agents. Available only if keep_agent_alive is 0. -check_interval: 60 # The time interval (seconds) of agents check with Redis diff --git a/maro/cli/process/executor.py b/maro/cli/process/executor.py deleted file mode 100644 index b3823fc77..000000000 --- a/maro/cli/process/executor.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
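The executor removed below stopped jobs by pushing their names onto a killed-job ticket list. In the new design there are no tickets: a job is removed by flipping its `status` field to `removed`, after which the monitor thread in the job manager terminates the processes or containers and deletes both Redis hashes. A sketch of that contract (the real command lives in maro.cli.local.commands, which is not shown here):

```python
# Illustrative only: request removal of a job under the new status-based contract.
import json

import redis

from maro.cli.local.utils import JobStatus, RedisHashKey


def remove_job_sketch(job_name: str, redis_port: int = 19999) -> None:
    redis_conn = redis.Redis(host="localhost", port=redis_port)
    details_raw = redis_conn.hget(RedisHashKey.JOB_DETAILS, job_name)
    if details_raw is None:
        print(f"no such job: {job_name}")
        return

    details = json.loads(details_raw)
    details["status"] = JobStatus.REMOVED  # the monitor thread does the actual cleanup
    redis_conn.hset(RedisHashKey.JOB_DETAILS, job_name, json.dumps(details))
```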
- -import copy -import json -import os -import shutil -import subprocess - -import redis -import yaml - -from maro.cli.grass.lib.services.utils.params import JobStatus -from maro.cli.process.utils.details import close_by_pid, get_redis_pid_by_port -from maro.cli.utils.abs_visible_executor import AbsVisibleExecutor -from maro.cli.utils.details_reader import DetailsReader -from maro.cli.utils.details_writer import DetailsWriter -from maro.cli.utils.params import GlobalPaths, LocalPaths, ProcessRedisName -from maro.cli.utils.resource_executor import LocalResourceExecutor -from maro.utils.logger import CliLogger - -logger = CliLogger(name=__name__) - - -class ProcessExecutor(AbsVisibleExecutor): - def __init__(self, details: dict = None): - self.details = details if details else \ - DetailsReader.load_cluster_details("process") - - # Connection with Redis - redis_port = self.details["redis_info"]["port"] - self._redis_connection = redis.Redis(host="localhost", port=redis_port) - try: - self._redis_connection.ping() - except Exception: - redis_process = subprocess.Popen( - ["redis-server", "--port", str(redis_port), "--daemonize yes"] - ) - redis_process.wait(timeout=2) - - # Connection with Resource Redis - self._resource_redis = LocalResourceExecutor() - - def create(self): - logger.info("Starting MARO Multi-Process Mode.") - if os.path.isdir(f"{GlobalPaths.ABS_MARO_CLUSTERS}/process"): - logger.warning("Process mode has been created.") - - # Get environment setting - DetailsWriter.save_cluster_details( - cluster_name="process", - cluster_details=self.details - ) - - # Start agents - command = f"python {LocalPaths.MARO_PROCESS_AGENT}" - _ = subprocess.Popen(command, shell=True) - self._redis_connection.hset(ProcessRedisName.SETTING, "agent_status", 1) - - # Add connection to resource Redis. - self._resource_redis.add_cluster() - - logger.info(f"MARO process mode setting: {self.details}") - - def delete(self): - process_setting = self._redis_connection.hgetall(ProcessRedisName.SETTING) - process_setting = { - key.decode(): json.loads(value) for key, value in process_setting.items() - } - - # Stop running jobs - jobs = self._redis_connection.hgetall(ProcessRedisName.JOB_DETAILS) - if jobs: - for job_name, job_detail in jobs.items(): - job_detail = json.loads(job_detail) - if job_detail["status"] == JobStatus.RUNNING: - close_by_pid(pid=job_detail["pid_list"], recursive=False) - logger.info(f"Stop running job {job_name.decode()}.") - - # Stop agents - agent_status = int(process_setting["agent_status"]) - if agent_status: - agent_pid = int(process_setting["agent_pid"]) - close_by_pid(pid=agent_pid, recursive=True) - logger.info("Close agents.") - else: - logger.info("Agents is already closed.") - - # Stop Redis or clear Redis - redis_mode = self.details["redis_mode"] - if redis_mode == "MARO": - redis_pid = get_redis_pid_by_port(self.details["redis_info"]["port"]) - close_by_pid(pid=redis_pid, recursive=False) - else: - self._redis_clear() - - # Rm connection from resource redis. - self._resource_redis.sub_cluster() - - logger.info("Redis cleared.") - - # Remove local process file. 
- shutil.rmtree(f"{GlobalPaths.ABS_MARO_CLUSTERS}/process", True) - logger.info("Process mode has been deleted.") - - def _redis_clear(self): - redis_keys = self._redis_connection.keys("process:*") - for key in redis_keys: - self._redis_connection.delete(key) - - def start_job(self, deployment_path: str): - # Load start_job_deployment - with open(deployment_path, "r") as fr: - start_job_deployment = yaml.safe_load(fr) - - job_name = start_job_deployment["name"] - start_job_deployment["status"] = JobStatus.PENDING - # Push job details to redis - self._redis_connection.hset( - ProcessRedisName.JOB_DETAILS, - job_name, - json.dumps(start_job_deployment) - ) - - self._push_pending_job(job_name) - - def _push_pending_job(self, job_name: str): - # Push job name to pending_job_tickets - self._redis_connection.lpush( - ProcessRedisName.PENDING_JOB_TICKETS, - job_name - ) - logger.info(f"Sending {job_name} into pending job tickets.") - - def stop_job(self, job_name: str): - if not self._redis_connection.hexists(ProcessRedisName.JOB_DETAILS, job_name): - logger.error(f"No such job '{job_name}' in Redis.") - return - - # push job_name into kill_job_tickets - self._redis_connection.lpush( - ProcessRedisName.KILLED_JOB_TICKETS, - job_name - ) - logger.info(f"Sending {job_name} into killed job tickets.") - - def delete_job(self, job_name: str): - # Stop job for running and pending job. - self.stop_job(job_name) - - # Rm job details in Redis - self._redis_connection.hdel(ProcessRedisName.JOB_DETAILS, job_name) - - # Rm job's log folder - job_folder = os.path.expanduser(f"{LocalPaths.MARO_PROCESS}/{job_name}") - shutil.rmtree(job_folder, True) - logger.info(f"Remove local temporary log folder {job_folder}.") - - def get_job_logs(self, job_name): - source_path = os.path.expanduser(f"{LocalPaths.MARO_PROCESS}/{job_name}") - if not os.path.exists(source_path): - logger.error(f"Cannot find the logs of {job_name}.") - - destination = os.path.join(os.getcwd(), job_name) - if os.path.exists(destination): - shutil.rmtree(destination) - shutil.copytree(source_path, destination) - logger.info(f"Dump logs in path: {destination}.") - - def list_job(self): - # Get all jobs - jobs = self._redis_connection.hgetall(ProcessRedisName.JOB_DETAILS) - for job_name, job_detail in jobs.items(): - job_name = job_name.decode() - job_detail = json.loads(job_detail) - - logger.info(job_detail) - - def start_schedule(self, deployment_path: str): - with open(deployment_path, "r") as fr: - schedule_detail = yaml.safe_load(fr) - - # push schedule details to Redis - self._redis_connection.hset( - ProcessRedisName.JOB_DETAILS, - schedule_detail["name"], - json.dumps(schedule_detail) - ) - - job_list = schedule_detail["job_names"] - # switch schedule details into job details - job_detail = copy.deepcopy(schedule_detail) - del job_detail["job_names"] - - for job_name in job_list: - job_detail["name"] = job_name - - # Push job details to redis - self._redis_connection.hset( - ProcessRedisName.JOB_DETAILS, - job_name, - json.dumps(job_detail) - ) - - self._push_pending_job(job_name) - - def stop_schedule(self, schedule_name: str): - if self._redis_connection.hexists(ProcessRedisName.JOB_DETAILS, schedule_name): - schedule_details = json.loads(self._redis_connection.hget(ProcessRedisName.JOB_DETAILS, schedule_name)) - else: - logger.error(f"Cannot find {schedule_name} in Redis. 
Please check schedule name.") - return - - if "job_names" not in schedule_details.keys(): - logger.error(f"'{schedule_name}' is not a schedule.") - return - - job_list = schedule_details["job_names"] - - for job_name in job_list: - self.stop_job(job_name) - - def get_job_details(self): - jobs = self._redis_connection.hgetall(ProcessRedisName.JOB_DETAILS) - for job_name, job_details_str in jobs.items(): - jobs[job_name] = json.loads(job_details_str) - - return list(jobs.values()) - - def get_job_queue(self): - pending_job_queue = self._redis_connection.lrange( - ProcessRedisName.PENDING_JOB_TICKETS, - 0, -1 - ) - killed_job_queue = self._redis_connection.lrange( - ProcessRedisName.KILLED_JOB_TICKETS, - 0, -1 - ) - return { - "pending_jobs": pending_job_queue, - "killed_jobs": killed_job_queue - } - - def get_resource(self): - return self._resource_redis.get_local_resource() - - def get_resource_usage(self, previous_length: int): - return self._resource_redis.get_local_resource_usage(previous_length) diff --git a/maro/cli/process/job.py b/maro/cli/process/job.py deleted file mode 100644 index 9f478cbb9..000000000 --- a/maro/cli/process/job.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - - -from maro.cli.process.executor import ProcessExecutor - - -def start_job(deployment_path: str, **kwargs): - executor = ProcessExecutor() - executor.start_job(deployment_path=deployment_path) - - -def stop_job(job_name: str, **kwargs): - executor = ProcessExecutor() - executor.stop_job(job_name=job_name) - - -def delete_job(job_name: str, **kwargs): - executor = ProcessExecutor() - executor.delete_job(job_name=job_name) - - -def list_jobs(**kwargs): - executor = ProcessExecutor() - executor.list_job() - - -def get_job_logs(job_name: str, **kwargs): - executor = ProcessExecutor() - executor.get_job_logs(job_name=job_name) diff --git a/maro/cli/process/schedule.py b/maro/cli/process/schedule.py deleted file mode 100644 index 0c7ca3188..000000000 --- a/maro/cli/process/schedule.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - - -from maro.cli.process.executor import ProcessExecutor - - -def start_schedule(deployment_path: str, **kwargs): - executor = ProcessExecutor() - executor.start_schedule(deployment_path=deployment_path) - - -def stop_schedule(schedule_name: str, **kwargs): - executor = ProcessExecutor() - executor.stop_schedule(schedule_name=schedule_name) diff --git a/maro/cli/process/template.py b/maro/cli/process/template.py deleted file mode 100644 index f1f8bb615..000000000 --- a/maro/cli/process/template.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import os -import shutil - -from maro.cli.utils.params import LocalPaths - - -def template(setting_deploy, export_path, **kwargs): - deploy_files = os.listdir(LocalPaths.MARO_PROCESS_DEPLOYMENT) - if not setting_deploy: - deploy_files.remove("process_setting_deployment.yml") - export_path = os.path.abspath(export_path) - for file_name in deploy_files: - if os.path.isfile(f"{LocalPaths.MARO_PROCESS_DEPLOYMENT}/{file_name}"): - shutil.copy(f"{LocalPaths.MARO_PROCESS_DEPLOYMENT}/{file_name}", export_path) diff --git a/maro/cli/process/utils/default_param.py b/maro/cli/process/utils/default_param.py deleted file mode 100644 index 41001f4d8..000000000 --- a/maro/cli/process/utils/default_param.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Microsoft Corporation. 
-# Licensed under the MIT license. - - -process_setting = { - "redis_info": { - "host": "localhost", - "port": 19999 - }, - "redis_mode": "MARO", # one of MARO, customized. customized Redis won't exit after maro process clear. - "parallel_level": 1, - "keep_agent_alive": 1, # If 0 (False), agents will exit after 5 minutes of no pending jobs and running jobs. - "check_interval": 60, # seconds - "agent_countdown": 5 # how many times to shutdown agents about finding no job in Redis. -} diff --git a/maro/cli/process/utils/details.py b/maro/cli/process/utils/details.py deleted file mode 100644 index a95756946..000000000 --- a/maro/cli/process/utils/details.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import os -import signal -import subprocess -from typing import Union - -import psutil - - -def close_by_pid(pid: Union[int, list], recursive: bool = False): - if isinstance(pid, int): - if not psutil.pid_exists(pid): - return - - if recursive: - current_process = psutil.Process(pid) - children_process = current_process.children(recursive=False) - # May launch by JobTrackingAgent which is child process, so need close parent process first. - current_process.kill() - for child_process in children_process: - child_process.kill() - else: - os.kill(pid, signal.SIGKILL) - else: - for p in pid: - if psutil.pid_exists(p): - os.kill(p, signal.SIGKILL) - - -def get_child_pid(parent_pid): - command = f"ps -o pid --ppid {parent_pid} --noheaders" - get_children_pid_process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE) - children_pids = get_children_pid_process.stdout.read() - get_children_pid_process.wait(timeout=2) - - # Convert into list or int - try: - children_pids = int(children_pids) - except ValueError: - children_pids = children_pids.decode().split("\n") - children_pids = [int(pid) for pid in children_pids[:-1]] - - return children_pids - - -def get_redis_pid_by_port(port: int): - get_redis_pid_command = f"pidof 'redis-server *:{port}'" - get_redis_pid_process = subprocess.Popen(get_redis_pid_command, shell=True, stdout=subprocess.PIPE) - redis_pid = int(get_redis_pid_process.stdout.read()) - get_redis_pid_process.wait() - - return redis_pid diff --git a/maro/cli/utils/azure/acr.py b/maro/cli/utils/azure/acr.py new file mode 100644 index 000000000..bb5da0215 --- /dev/null +++ b/maro/cli/utils/azure/acr.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json + +from maro.cli.utils.subprocess import Subprocess + + +def login_acr(acr_name: str) -> None: + command = f"az acr login --name {acr_name}" + _ = Subprocess.run(command=command) + + +def list_acr_repositories(acr_name: str) -> list: + command = f"az acr repository list -n {acr_name}" + return_str = Subprocess.run(command=command) + return json.loads(return_str) diff --git a/maro/cli/utils/azure/aks.py b/maro/cli/utils/azure/aks.py new file mode 100644 index 000000000..44632b8cd --- /dev/null +++ b/maro/cli/utils/azure/aks.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
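The helpers in maro/cli/utils/azure/acr.py wrap the Azure CLI, so they assume `az` is installed and already logged into the target account. A small usage sketch with placeholder registry and repository names:

```python
# Check whether an image repository already exists in ACR before pushing.
# ACR_NAME and IMAGE_REPO are placeholders.
from maro.cli.utils.azure.acr import list_acr_repositories, login_acr

ACR_NAME = "myregistry"   # placeholder registry name
IMAGE_REPO = "maro-rl"    # placeholder repository name

login_acr(ACR_NAME)
if IMAGE_REPO in list_acr_repositories(ACR_NAME):
    print(f"{IMAGE_REPO} already exists in {ACR_NAME}")
else:
    print(f"{IMAGE_REPO} not found; push it with maro.cli.utils.docker.push")
```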
+ +import subprocess + +from azure.identity import DefaultAzureCredential +from azure.mgmt.authorization import AuthorizationManagementClient +from azure.mgmt.containerservice import ContainerServiceClient + +from maro.cli.utils.subprocess import Subprocess + + +def get_container_service_client(subscription: str): + return ContainerServiceClient(DefaultAzureCredential(), subscription) + + +def get_authorization_client(subscription: str): + return AuthorizationManagementClient() + + +def load_aks_context(resource_group: str, aks_name: str) -> None: + command = f"az aks get-credentials -g {resource_group} --name {aks_name}" + _ = Subprocess.run(command=command) + + +def get_aks(subscription: str, resource_group: str, aks_name: str) -> dict: + container_service_client = get_container_service_client(subscription) + return container_service_client.managed_clusters.get(resource_group, aks_name) + + +def attach_acr(resource_group: str, aks_name: str, acr_name: str) -> None: + subprocess.run(f"az aks update -g {resource_group} -n {aks_name} --attach-acr {acr_name}".split()) + + +def add_nodepool(resource_group: str, aks_name: str, nodepool_name: str, node_count: int, node_size: str) -> None: + command = ( + f"az aks nodepool add " + f"-g {resource_group} " + f"--cluster-name {aks_name} " + f"--name {nodepool_name} " + f"--node-count {node_count} " + f"--node-vm-size {node_size}" + ) + _ = Subprocess.run(command=command) + + +def scale_nodepool(resource_group: str, aks_name: str, nodepool_name: str, node_count: int) -> None: + command = ( + f"az aks nodepool scale " + f"-g {resource_group} " + f"--cluster-name {aks_name} " + f"--name {nodepool_name} " + f"--node-count {node_count}" + ) + _ = Subprocess.run(command=command) diff --git a/maro/cli/utils/azure/deployment.py b/maro/cli/utils/azure/deployment.py new file mode 100644 index 000000000..a2cc1cf5b --- /dev/null +++ b/maro/cli/utils/azure/deployment.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .general import get_resource_client + + +def create_deployment( + subscription: str, + resource_group: str, + deployment_name: str, + template: dict, + params: dict, + sync: bool = True +) -> None: + params = {k: {"value": v} for k, v in params.items()} + resource_client = get_resource_client(subscription) + deployment_params = {"mode": "Incremental", "template": template, "parameters": params} + result = resource_client.deployments.begin_create_or_update( + resource_group, deployment_name, {"properties": deployment_params} + ) + if sync: + result.result() + + +def delete_deployment(subscription: str, resource_group: str, deployment_name: str) -> None: + resource_client = get_resource_client(subscription) + resource_client.deployments.begin_delete(resource_group, deployment_name) diff --git a/maro/cli/utils/azure/general.py b/maro/cli/utils/azure/general.py new file mode 100644 index 000000000..83e6968c7 --- /dev/null +++ b/maro/cli/utils/azure/general.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
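create_deployment() above wraps each parameter value as {"value": ...} and submits an ARM deployment through the resource client, optionally blocking until it completes. A minimal sketch with an empty template and placeholder subscription and resource-group names, assuming Azure credentials are available to DefaultAzureCredential:

```python
# Submit a do-nothing ARM deployment to exercise the helper above.
from maro.cli.utils.azure.deployment import create_deployment

template = {
    "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
    "contentVersion": "1.0.0.0",
    "parameters": {"location": {"type": "string"}},
    "resources": [],
}

create_deployment(
    subscription="<subscription-id>",    # placeholder
    resource_group="maro-test-rg",       # placeholder
    deployment_name="maro-empty-deployment",
    template=template,
    params={"location": "eastus"},       # wrapped into {"value": ...} internally
    sync=True,                           # block until the deployment completes
)
```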
+ +import json +import os +import subprocess + +from azure.identity import DefaultAzureCredential +from azure.mgmt.resource import ResourceManagementClient + +from maro.cli.utils.subprocess import Subprocess + + +def set_subscription(subscription: str) -> None: + command = f"az account set --subscription {subscription}" + _ = Subprocess.run(command=command) + + +def get_version() -> dict: + command = "az version" + return_str = Subprocess.run(command=command) + return json.loads(return_str) + + +def get_resource_client(subscription: str): + return ResourceManagementClient(DefaultAzureCredential(), subscription) + + +def set_env_credentials(dump_path: str, service_principal_name: str): + os.makedirs(dump_path, exist_ok=True) + service_principal_file_path = os.path.join(dump_path, f"{service_principal_name}.json") + # If the service principal file does not exist, create one using the az CLI command. + # For details on service principals, refer to + # https://docs.microsoft.com/en-us/azure/active-directory/develop/app-objects-and-service-principals + if not os.path.exists(service_principal_file_path): + with open(service_principal_file_path, 'w') as fp: + subprocess.run( + f"az ad sp create-for-rbac --name {service_principal_name} --sdk-auth --role contributor".split(), + stdout=fp + ) + + with open(service_principal_file_path, 'r') as fp: + service_principal = json.load(fp) + + os.environ["AZURE_TENANT_ID"] = service_principal["tenantId"] + os.environ["AZURE_CLIENT_ID"] = service_principal["clientId"] + os.environ["AZURE_CLIENT_SECRET"] = service_principal["clientSecret"] + os.environ["AZURE_SUBSCRIPTION_ID"] = service_principal["subscriptionId"] + + +def connect_to_aks(resource_group: str, aks: str): + subprocess.run(f"az aks get-credentials --resource-group {resource_group} --name {aks}".split()) + + +def get_acr_push_permissions(service_principal_id: str, acr: str): + acr_id = json.loads( + subprocess.run(f"az acr show --name {acr} --query id".split(), stdout=subprocess.PIPE).stdout + ) + subprocess.run( + f"az role assignment create --assignee {service_principal_id} --scope {acr_id} --role acrpush".split() + ) + subprocess.run(f"az acr login --name {acr}".split()) diff --git a/maro/cli/utils/azure/resource_group.py b/maro/cli/utils/azure/resource_group.py new file mode 100644 index 000000000..8af9d76ba --- /dev/null +++ b/maro/cli/utils/azure/resource_group.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json + +from maro.cli.utils.subprocess import Subprocess +from maro.utils.exception.cli_exception import CommandExecutionError + +from .general import get_resource_client + + +def get_resource_group(resource_group: str) -> dict: + command = f"az group show --name {resource_group}" + try: + return_str = Subprocess.run(command=command) + return json.loads(return_str) + except CommandExecutionError: + return {} + + +def delete_resource_group(resource_group: str) -> None: + command = f"az group delete --yes --name {resource_group}" + _ = Subprocess.run(command=command) + + +# Chained Azure resource group operations +def create_resource_group(subscription: str, resource_group: str, location: str): + """Create the resource group if it does not exist. + + Args: + subscription (str): Azure subscription name. + resource group (str): Resource group name. + location (str): Reousrce group location. + + Returns: + None. 
+ """ + resource_client = get_resource_client(subscription) + return resource_client.resource_groups.create_or_update(resource_group, {"location": location}) + + +def delete_resource_group_under_subscription(subscription: str, resource_group: str): + resource_client = get_resource_client(subscription) + return resource_client.resource_groups.begin_delete(resource_group) diff --git a/maro/cli/utils/azure/resources.py b/maro/cli/utils/azure/resources.py new file mode 100644 index 000000000..d8d025bb7 --- /dev/null +++ b/maro/cli/utils/azure/resources.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json + +from maro.cli.utils.subprocess import Subprocess + + +def list_resources(resource_group: str) -> list: + command = f"az resource list -g {resource_group}" + return_str = Subprocess.run(command=command) + return json.loads(return_str) + + +def delete_resources(resource_ids: list) -> None: + command = f"az resource delete --ids {' '.join(resource_ids)}" + _ = Subprocess.run(command=command) + + +def cleanup(cluster_name: str, resource_group: str) -> None: + # Get resource list + resource_list = list_resources(resource_group) + + # Filter resources + deletable_ids = [] + for resource in resource_list: + if resource["name"].startswith(cluster_name): + deletable_ids.append(resource["id"]) + + # Delete resources + if deletable_ids: + delete_resources(resource_ids=deletable_ids) diff --git a/maro/cli/utils/azure/storage.py b/maro/cli/utils/azure/storage.py new file mode 100644 index 000000000..3dcdb911f --- /dev/null +++ b/maro/cli/utils/azure/storage.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import datetime +import json +import os +from typing import Union + +from azure.core.exceptions import ResourceExistsError +from azure.storage.fileshare import ShareClient, ShareDirectoryClient + +from maro.cli.utils.subprocess import Subprocess + + +def get_storage_account_keys(resource_group: str, storage_account_name: str) -> dict: + command = f"az storage account keys list -g {resource_group} --account-name {storage_account_name}" + return_str = Subprocess.run(command=command) + return json.loads(return_str) + + +def get_storage_account_sas( + account_name: str, + services: str = "bqtf", + resource_types: str = "sco", + permissions: str = "rwdlacup", + expiry: str = (datetime.datetime.utcnow() + datetime.timedelta(days=365)).strftime("%Y-%m-%dT%H:%M:%S") + "Z" +) -> str: + command = ( + f"az storage account generate-sas --account-name {account_name} --services {services} " + f"--resource-types {resource_types} --permissions {permissions} --expiry {expiry}" + ) + sas_str = Subprocess.run(command=command).strip("\n").replace('"', "") + # logger.debug(sas_str) + return sas_str + + +def get_connection_string(storage_account_name: str) -> str: + """Get the connection string for a storage account. + + Args: + storage_account_name: The storage account name. + + Returns: + str: Connection string. 
+ """ + command = f"az storage account show-connection-string --name {storage_account_name}" + return_str = Subprocess.run(command=command) + return json.loads(return_str)["connectionString"] + + +def get_fileshare(storage_account_name: str, fileshare_name: str): + connection_string = get_connection_string(storage_account_name) + share = ShareClient.from_connection_string(connection_string, fileshare_name) + try: + share.create_share() + except ResourceExistsError: + pass + + return share + + +def get_directory(share: Union[ShareClient, ShareDirectoryClient], name: str): + if isinstance(share, ShareClient): + directory = share.get_directory_client(directory_path=name) + try: + directory.create_directory() + except ResourceExistsError: + pass + + return directory + elif isinstance(share, ShareDirectoryClient): + try: + return share.create_subdirectory(name) + except ResourceExistsError: + return share.get_subdirectory_client(name) + + +def upload_to_fileshare(share: Union[ShareClient, ShareDirectoryClient], source_path: str, name: str = None): + if os.path.isdir(source_path): + if not name: + name = os.path.basename(source_path) + directory = get_directory(share, name) + for file in os.listdir(source_path): + upload_to_fileshare(directory, os.path.join(source_path, file)) + else: + with open(source_path, "rb") as fp: + share.upload_file(file_name=os.path.basename(source_path), data=fp) + + +def download_from_fileshare(share: ShareDirectoryClient, file_name: str, local_path: str): + file = share.get_file_client(file_name=file_name) + with open(local_path, "wb") as fp: + fp.write(file.download_file().readall()) + + +def delete_directory(share: Union[ShareClient, ShareDirectoryClient], name: str, recursive: bool = True): + share.delete_directory(directory_name=name) diff --git a/maro/cli/utils/azure/vm.py b/maro/cli/utils/azure/vm.py new file mode 100644 index 000000000..f7b12ff8c --- /dev/null +++ b/maro/cli/utils/azure/vm.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
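The Azure Files helpers above create shares and directories on demand and recurse over local folders when uploading; they rely on the Azure CLI being logged in so the connection string can be fetched. A usage sketch with placeholder account, share, and file names:

```python
# Upload a local checkpoint folder to an Azure file share and pull one file back.
# Account, share, folder and file names are placeholders.
from maro.cli.utils.azure.storage import (
    download_from_fileshare, get_directory, get_fileshare, upload_to_fileshare
)

share = get_fileshare(storage_account_name="marostorage", fileshare_name="maro-jobs")  # placeholders
upload_to_fileshare(share, source_path="./checkpoints", name="checkpoints")

checkpoints_dir = get_directory(share, "checkpoints")
download_from_fileshare(checkpoints_dir, file_name="policy.ckpt", local_path="./policy.ckpt")
```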
+ +import json + +from maro.cli.utils.subprocess import Subprocess + + +def list_ip_addresses(resource_group: str, vm_name: str) -> list: + command = f"az vm list-ip-addresses -g {resource_group} --name {vm_name}" + return_str = Subprocess.run(command=command) + return json.loads(return_str) + + +def start_vm(resource_group: str, vm_name: str) -> None: + command = f"az vm start -g {resource_group} --name {vm_name}" + _ = Subprocess.run(command=command) + + +def stop_vm(resource_group: str, vm_name: str) -> None: + command = f"az vm stop -g {resource_group} --name {vm_name}" + _ = Subprocess.run(command=command) + + +def list_vm_sizes(location: str) -> list: + command = f"az vm list-sizes -l {location}" + return_str = Subprocess.run(command=command) + return json.loads(return_str) + + +def deallocate_vm(resource_group: str, vm_name: str) -> None: + command = f"az vm deallocate --resource-group {resource_group} --name {vm_name}" + _ = Subprocess.run(command=command) + + +def generalize_vm(resource_group: str, vm_name: str) -> None: + command = f"az vm generalize --resource-group {resource_group} --name {vm_name}" + _ = Subprocess.run(command=command) + + +def create_image_from_vm(resource_group: str, image_name: str, vm_name: str) -> None: + command = f"az image create --resource-group {resource_group} --name {image_name} --source {vm_name}" + _ = Subprocess.run(command=command) + + +def get_image_resource_id(resource_group: str, image_name: str) -> str: + command = f"az image show --resource-group {resource_group} --name {image_name}" + return_str = Subprocess.run(command=command) + return json.loads(return_str)["id"] diff --git a/maro/cli/utils/common.py b/maro/cli/utils/common.py index 50927e922..e3e0bfb50 100644 --- a/maro/cli/utils/common.py +++ b/maro/cli/utils/common.py @@ -1,7 +1,55 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
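The VM helpers above map onto the usual Azure image-capture sequence. A sketch of chaining them, with placeholder resource names; the VM is assumed to have been deprovisioned (e.g. via waagent) beforehand:

```python
# Capture a reusable image from a prepared VM: deallocate, generalize,
# create the image, then look up its resource id. All names are placeholders.
from maro.cli.utils.azure.vm import (
    create_image_from_vm, deallocate_vm, generalize_vm, get_image_resource_id
)

RG, VM, IMAGE = "maro-rg", "maro-build-vm", "maro-worker-image"  # placeholders

deallocate_vm(resource_group=RG, vm_name=VM)
generalize_vm(resource_group=RG, vm_name=VM)
create_image_from_vm(resource_group=RG, image_name=IMAGE, vm_name=VM)
print(get_image_resource_id(resource_group=RG, image_name=IMAGE))
```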
+import os +import subprocess import sys +from collections import deque + +import psutil + +from maro.utils import Logger + + +def close_by_pid(pid: int, recursive: bool = True): + if not psutil.pid_exists(pid): + return + + proc = psutil.Process(pid) + if recursive: + for child in proc.children(recursive=recursive): + child.kill() + + proc.kill() + + +def get_child_pids(parent_pid): + # command = f"ps -o pid --ppid {parent_pid} --noheaders" + # get_children_pid_process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE) + # children_pids = get_children_pid_process.stdout.read() + # get_children_pid_process.wait(timeout=2) + + # # Convert into list or int + # try: + # children_pids = int(children_pids) + # except ValueError: + # children_pids = children_pids.decode().split("\n") + # children_pids = [int(pid) for pid in children_pids[:-1]] + + # return children_pids + try: + return [child.pid for child in psutil.Process(parent_pid).children(recursive=True)] + except psutil.NoSuchProcess: + print(f"No process with PID {parent_pid} found") + return + + +def get_redis_pid_by_port(port: int): + get_redis_pid_command = f"pidof 'redis-server *:{port}'" + get_redis_pid_process = subprocess.Popen(get_redis_pid_command, shell=True, stdout=subprocess.PIPE) + redis_pid = int(get_redis_pid_process.stdout.read()) + get_redis_pid_process.wait() + return redis_pid def exit(state: int = 0, msg: str = None): @@ -10,3 +58,75 @@ def exit(state: int = 0, msg: str = None): sys.stderr.write(msg) sys.exit(state) + + +def get_last_k_lines(file_name: str, k: int): + """ + Helper function to retrieve the last K lines from a file in a memory-efficient way. + + Code slightly adapted from https://thispointer.com/python-get-last-n-lines-of-a-text-file-like-tail-command/ + """ + # Create an empty list to keep the track of last k lines + lines = deque() + # Open file for reading in binary mode + with open(file_name, 'rb') as fp: + # Move the cursor to the end of the file + fp.seek(0, os.SEEK_END) + # Create a buffer to keep the last read line + buffer = bytearray() + # Get the current position of pointer i.e eof + ptr = fp.tell() + # Loop till pointer reaches the top of the file + while ptr >= 0: + # Move the file pointer to the location pointed by ptr + fp.seek(ptr) + # Shift pointer location by -1 + ptr -= 1 + # read that byte / character + new_byte = fp.read(1) + # If the read byte is new line character then it means one line is read + if new_byte != b'\n': + # If last read character is not eol then add it in buffer + buffer.extend(new_byte) + elif buffer: + lines.appendleft(buffer.decode()[::-1]) + if len(lines) == k: + return lines + # Reinitialize the byte array to save next line + buffer.clear() + + # As file is read completely, if there is still data in buffer, then it's the first of the last K lines. 
+ if buffer: + lines.appendleft(buffer.decode()[::-1]) + + return lines + + +def show_log(log_path: str, tail: int = -1, logger: Logger = None): + print_fn = logger.info if logger else print + if tail == -1: + with open(log_path, "r") as fp: + for line in fp: + print_fn(line.rstrip('\n')) + else: + for line in get_last_k_lines(log_path, tail): + print_fn(line) + + +def format_env_vars(env: dict, mode: str = "proc"): + if mode == "proc": + return env + + if mode == "docker": + env_opt_list = [] + for key, val in env.items(): + env_opt_list.extend(["--env", f"{key}={val}"]) + return env_opt_list + + if mode == "docker-compose": + return [f"{key}={val}" for key, val in env.items()] + + if mode == "k8s": + return [{"name": key, "value": val} for key, val in env.items()] + + raise ValueError(f"'mode' should be one of 'proc', 'docker', 'docker-compose', 'k8s', got {mode}") diff --git a/maro/cli/utils/docker.py b/maro/cli/utils/docker.py new file mode 100644 index 000000000..a22863697 --- /dev/null +++ b/maro/cli/utils/docker.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import docker + + +def image_exists(image_name: str): + try: + client = docker.from_env() + client.images.get(image_name) + return True + except docker.errors.ImageNotFound: + return False + + +def build_image(context: str, docker_file_path: str, image_name: str): + client = docker.from_env() + with open(docker_file_path, "r"): + client.images.build( + path=context, + tag=image_name, + quiet=False, + rm=True, + custom_context=False, + dockerfile=docker_file_path + ) + + +def push(local_image_name: str, repository: str): + client = docker.from_env() + image = client.images.get(local_image_name) + acr_tag = f"{repository}/{local_image_name}" + image.tag(acr_tag) + # subprocess.run(f"docker push {acr_tag}".split()) + client.images.push(acr_tag) + print(f"Pushed image to {acr_tag}") diff --git a/maro/cli/utils/params.py b/maro/cli/utils/params.py index 9bfd173de..51280c5b5 100644 --- a/maro/cli/utils/params.py +++ b/maro/cli/utils/params.py @@ -38,23 +38,3 @@ class LocalParams: CPU_USAGE = "local_resource:cpu_usage_per_core" MEMORY_USAGE = "local_resource:memory_usage" GPU_USAGE = "local_resource:gpu_memory_usage" - - -class LocalPaths: - """Only use by maro process cli""" - MARO_PROCESS = "~/.maro/clusters/process" - MARO_PROCESS_AGENT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../process/agent/job_agent.py") - MARO_RESOURCE_AGENT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../process/agent/resource_agent.py") - MARO_PROCESS_DEPLOYMENT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../process/deployment") - MARO_GRASS_LOCAL_AGENT = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "../grass/lib/services/master_agent/local_agent.py" - ) - - -class ProcessRedisName: - """Record Redis elements name, and only for maro process""" - PENDING_JOB_TICKETS = "process:pending_job_tickets" - KILLED_JOB_TICKETS = "process:killed_job_tickets" - JOB_DETAILS = "process:job_details" - SETTING = "process:setting" diff --git a/maro/communication/dist_decorator.py b/maro/communication/dist_decorator.py index 40839e86d..3cc88a5fb 100644 --- a/maro/communication/dist_decorator.py +++ b/maro/communication/dist_decorator.py @@ -26,7 +26,7 @@ def __init__(self, *args, **kwargs): self.local_instance = cls(*args, **kwargs) self.proxy = proxy self._handler_function = {} - self._registry_table = RegisterTable(self.proxy.peers_name) + 
self._registry_table = RegisterTable(self.proxy.peers) # Use functools.partial to freeze handling function's local_instance and proxy # arguments to self.local_instance and self.proxy. for constraint, handler_fun in handler_dict.items(): diff --git a/maro/communication/driver/zmq_driver.py b/maro/communication/driver/zmq_driver.py index 1998fcce0..3fe73ad4e 100644 --- a/maro/communication/driver/zmq_driver.py +++ b/maro/communication/driver/zmq_driver.py @@ -69,7 +69,7 @@ def _setup_sockets(self): """ self._unicast_receiver = self._zmq_context.socket(zmq.PULL) unicast_receiver_port = self._unicast_receiver.bind_to_random_port(f"{self._protocol}://*") - self._logger.info(f"Receive message via unicasting at {self._ip_address}:{unicast_receiver_port}.") + self._logger.debug(f"Receive message via unicasting at {self._ip_address}:{unicast_receiver_port}.") # Dict about zmq.PUSH sockets, fulfills in self.connect. self._unicast_sender_dict = {} @@ -80,7 +80,7 @@ def _setup_sockets(self): self._broadcast_receiver = self._zmq_context.socket(zmq.SUB) self._broadcast_receiver.setsockopt(zmq.SUBSCRIBE, self._component_type.encode()) broadcast_receiver_port = self._broadcast_receiver.bind_to_random_port(f"{self._protocol}://*") - self._logger.info(f"Subscriber message at {self._ip_address}:{broadcast_receiver_port}.") + self._logger.debug(f"Subscriber message at {self._ip_address}:{broadcast_receiver_port}.") # Record own sockets' address. self._address = { @@ -122,10 +122,10 @@ def connect(self, peers_address_dict: Dict[str, Dict[str, str]]): self._unicast_sender_dict[peer_name] = self._zmq_context.socket(zmq.PUSH) self._unicast_sender_dict[peer_name].setsockopt(zmq.SNDTIMEO, self._send_timeout) self._unicast_sender_dict[peer_name].connect(address) - self._logger.info(f"Connects to {peer_name} via unicasting.") + self._logger.debug(f"Connects to {peer_name} via unicasting.") elif int(socket_type) == zmq.SUB: self._broadcast_sender.connect(address) - self._logger.info(f"Connects to {peer_name} via broadcasting.") + self._logger.debug(f"Connects to {peer_name} via broadcasting.") else: raise SocketTypeError(f"Unrecognized socket type {socket_type}.") except Exception as e: @@ -158,13 +158,13 @@ def disconnect(self, peers_address_dict: Dict[str, Dict[str, str]]): raise PeersDisconnectionError(f"Driver cannot disconnect to {peer_name}! Due to {str(e)}") self._disconnected_peer_name_list.append(peer_name) - self._logger.info(f"Disconnected with {peer_name}.") + self._logger.debug(f"Disconnected with {peer_name}.") - def receive(self, is_continuous: bool = True, timeout: int = None): + def receive(self, timeout: int = None): """Receive message from ``zmq.POLLER``. Args: - is_continuous (bool): Continuously receive message or not. Defaults to True. + timeout (int): Timeout for polling. If the first poll times out, the function returns None. Yields: recv_message (Message): The received message from the poller. @@ -184,13 +184,38 @@ def receive(self, is_continuous: bool = True, timeout: int = None): recv_message = pickle.loads(recv_message) self._logger.debug(f"Receive a message from {recv_message.source} through broadcast receiver.") else: - self._logger.debug(f"Cannot receive any message within {receive_timeout}.") + self._logger.debug(f"No message received within {receive_timeout}.") return yield recv_message - if not is_continuous: - break + def receive_once(self, timeout: int = None): + """Receive a single message from ``zmq.POLLER``. + + Args: + timeout (int): Time-out for ZMQ polling. 
If the first poll times out, the function returns None. + + Returns: + recv_message (Message): The received message from the poller or None if the poller times out. + """ + receive_timeout = timeout if timeout else self._receive_timeout + try: + sockets = dict(self._poller.poll(receive_timeout)) + except Exception as e: + raise DriverReceiveError(f"Driver cannot receive message as {e}") + + if self._unicast_receiver in sockets: + recv_message = self._unicast_receiver.recv_pyobj() + self._logger.debug(f"Receive a message from {recv_message.source} through unicast receiver.") + elif self._broadcast_receiver in sockets: + _, recv_message = self._broadcast_receiver.recv_multipart() + recv_message = pickle.loads(recv_message) + self._logger.debug(f"Receive a message from {recv_message.source} through broadcast receiver.") + else: + self._logger.debug(f"No message received within {receive_timeout}.") + return + + return recv_message def send(self, message: Message): """Send message. diff --git a/maro/communication/message.py b/maro/communication/message.py index 6ae4cdc8c..92e94bff0 100644 --- a/maro/communication/message.py +++ b/maro/communication/message.py @@ -48,35 +48,35 @@ class Message(object): tag (str|Enum): Message tag, which is customized by the user, for specific application logic. source (str): The sender of message. destination (str): The receiver of message. - payload (object): Message payload, such as model parameters, experiences, etc. Defaults to None. + body (object): Message body, such as model parameters, experiences, etc. Defaults to None. session_id (str): Message belonged session id, it will be generated automatically by default, you can use it to group message based on your application logic. """ - def __init__(self, tag: Union[str, Enum], source: str, destination: str, payload=None): + def __init__(self, tag: Union[str, Enum], source: str, destination: str, body=None): self.tag = tag self.source = source self.destination = destination - self.payload = {} if payload is None else payload + self.body = {} if body is None else body self.session_id = session_id_generator(self.source, self.destination) self.message_id = str(uuid.uuid1()) def __repr__(self): return "; \n".join([f"{k} = {v}" for k, v in vars(self).items()]) - def reply(self, tag: Union[str, Enum] = None, payload=None): + def reply(self, tag: Union[str, Enum] = None, body=None): self.source, self.destination = self.destination, self.source if tag: self.tag = tag - self.payload = payload + self.body = body self.message_id = str(uuid.uuid1()) - def forward(self, destination: str, tag: Union[str, Enum] = None, payload=None): + def forward(self, destination: str, tag: Union[str, Enum] = None, body=None): self.source = self.destination self.destination = destination if tag: self.tag = tag - self.payload = payload + self.body = body self.message_id = str(uuid.uuid1()) @@ -94,11 +94,11 @@ def __init__( self, tag: Union[str, Enum], source: str, destination: str, - payload=None, + body=None, session_type: SessionType = SessionType.TASK, session_stage=None ): - super().__init__(tag, source, destination, payload) + super().__init__(tag, source, destination, body) self.session_type = session_type if self.session_type == SessionType.TASK: diff --git a/maro/communication/proxy.py b/maro/communication/proxy.py index 31e840af7..000fb5be5 100644 --- a/maro/communication/proxy.py +++ b/maro/communication/proxy.py @@ -16,7 +16,7 @@ import redis # private lib -from maro.utils import DummyLogger, InternalLogger +from maro.utils 
import Logger from maro.utils.exception.communication_exception import InformationUncompletedError, PeersMissError, PendingToSend from maro.utils.exit_code import KILL_ALL_EXIT_CODE, NON_RESTART_EXIT_CODE @@ -27,8 +27,10 @@ _PEER_INFO = namedtuple("PEER_INFO", ["hash_table_name", "expected_number"]) HOST = default_parameters.proxy.redis.host PORT = default_parameters.proxy.redis.port -MAX_RETRIES = default_parameters.proxy.redis.max_retries -BASE_RETRY_INTERVAL = default_parameters.proxy.redis.base_retry_interval +INITIAL_REDIS_CONNECT_RETRY_INTERVAL = default_parameters.proxy.redis.initial_retry_interval +MAX_REDIS_CONNECT_RETRIES = default_parameters.proxy.redis.max_retries +INITIAL_PEER_DISCOVERY_RETRY_INTERVAL = default_parameters.proxy.peer_discovery.initial_retry_interval +MAX_PEER_DISCOVERY_RETRIES = default_parameters.proxy.peer_discovery.max_retries DELAY_FOR_SLOW_JOINER = default_parameters.proxy.delay_for_slow_joiner ENABLE_REJOIN = default_parameters.proxy.peer_rejoin.enable # Only enable at real k8s cluster or grass cluster PEERS_CATCH_LIFETIME = default_parameters.proxy.peer_rejoin.peers_catch_lifetime @@ -55,9 +57,12 @@ class Proxy: Defaults to ``DriverType.ZMQ``. driver_parameters (Dict): The arguments for communication driver class initial. Defaults to None. redis_address (Tuple): Hostname and port of the Redis server. Defaults to ("localhost", 6379). - max_retries (int): Maximum number of retries before raising an exception. Defaults to 5. - retry_interval_base_value (float): The time interval between attempts. Defaults to 0.1. - log_enable (bool): Open internal logger or not. Defaults to True. + initial_redis_connect_retry_interval: Base value for the wait time between retries to connect to Redis. + Retries follow the exponential backoff algorithm. Defaults to 0.1. + max_redis_connect_retries: Maximum number of retries to connect to Redis. Defaults to 5. + initial_peer_discovery_retry_interval: Base value for the wait time between retries to find peers. + Retries follow the exponential backoff algorithm. Defaults to 0.1. + max_peer_discovery_retries: Maximum number of retries to find peers. Defaults to 5. enable_rejoin (bool): Allow peers rejoin or not. Defaults to False, and must use with maro cli. minimal_peers Union[int, dict]: The minimal number of peers for each peer type. peers_catch_lifetime (int): The lifetime for onboard peers' information. 
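Both retry knobs documented above follow the same exponential backoff scheme. A minimal standalone sketch of that scheme (illustrative only, not the Proxy's actual method; the function name is made up):

import time

def ping_with_backoff(ping, initial_retry_interval: float = 0.1, max_retries: int = 5) -> bool:
    # Doubling-interval retry loop, as described for Redis connection and peer discovery above.
    next_retry = initial_retry_interval
    for _ in range(max_retries):
        try:
            ping()            # e.g. a Redis ping or a peer-registry check
            return True
        except Exception:
            time.sleep(next_retry)
            next_retry *= 2   # 0.1, 0.2, 0.4, ... seconds
    return False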
@@ -73,12 +78,14 @@ def __init__( group_name: str, component_type: str, expected_peers: dict, + component_name: str = None, driver_type: DriverType = DriverType.ZMQ, driver_parameters: dict = None, redis_address: Tuple = (HOST, PORT), - max_retries: int = MAX_RETRIES, - retry_interval_base_value: float = BASE_RETRY_INTERVAL, - log_enable: bool = True, + initial_redis_connect_retry_interval: int = INITIAL_REDIS_CONNECT_RETRY_INTERVAL, + max_redis_connect_retries: int = MAX_REDIS_CONNECT_RETRIES, + initial_peer_discovery_retry_interval: int = INITIAL_PEER_DISCOVERY_RETRY_INTERVAL, + max_peer_discovery_retries: int = MAX_PEER_DISCOVERY_RETRIES, enable_rejoin: bool = ENABLE_REJOIN, minimal_peers: Union[int, dict] = MINIMAL_PEERS, peers_catch_lifetime: int = PEERS_CATCH_LIFETIME, @@ -91,15 +98,16 @@ def __init__( self._group_name = group_name self._component_type = component_type self._redis_hash_name = f"{self._group_name}:{self._component_type}" - if "COMPONENT_NAME" in os.environ: - self._name = os.getenv("COMPONENT_NAME") + if component_name is not None: + self._name = component_name else: unique_id = str(uuid.uuid1()).replace("-", "") - self._name = f"{self._component_type}_proxy_{unique_id}" - self._max_retries = max_retries - self._retry_interval_base_value = retry_interval_base_value - self._log_enable = log_enable - self._logger = InternalLogger(component_name=self._name) if self._log_enable else DummyLogger() + self._name = f"{self._component_type}_{unique_id}" + self._initial_redis_connect_retry_interval = initial_redis_connect_retry_interval + self._max_redis_connect_retries = max_redis_connect_retries + self._initial_peer_discovery_retry_interval = initial_peer_discovery_retry_interval + self._max_peer_discovery_retries = max_peer_discovery_retries + self._logger = Logger(".".join([self._name, "proxy"])) # TODO:In multiprocess with spawn start method, the driver must be initiated before the Redis. # Otherwise it will cause Error 9: Bad File Descriptor in proxy.__del__(). Root cause not found. @@ -112,12 +120,28 @@ def __init__( self._logger.error(f"Unsupported driver type {driver_type}, please use DriverType class.") sys.exit(NON_RESTART_EXIT_CODE) - # Initialize the Redis. + # Initialize connection to the redis server. self._redis_connection = redis.Redis(host=redis_address[0], port=redis_address[1], socket_keepalive=True) - try: - self._redis_connection.ping() - except Exception as e: - self._logger.error(f"{self._name} failure to connect to redis server due to {e}") + next_retry, success = self._initial_redis_connect_retry_interval, False + for _ in range(self._max_redis_connect_retries): + try: + self._redis_connection.ping() + success = True + break + except Exception as e: + self._logger.error( + f"{self._name} failed to connect to Redis due to {e}. Retrying in {next_retry} seconds." + ) + time.sleep(next_retry) + next_retry *= 2 + + if success: + self._logger.debug( + f"{self._name} is successfully connected to the redis server " + f"at {redis_address[0]}:{redis_address[1]}." + ) + else: + self._logger.error(f"{self._name} failed to connect to the redis server.") sys.exit(NON_RESTART_EXIT_CODE) # Record the peer's redis information. @@ -209,7 +233,6 @@ def _register_redis(self): the value of table is the peer's socket address. """ self._redis_connection.hset(self._redis_hash_name, self._name, json.dumps(self._driver.address)) - # Handle interrupt signal for clearing Redis record. 
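With the updated constructor shown above, a proxy can now be given an explicit component name instead of relying on the COMPONENT_NAME environment variable. A hypothetical instantiation (group, peer and component names are placeholders):

# Hypothetical usage of the new Proxy signature; all names and counts are placeholders.
from maro.communication import Proxy

proxy = Proxy(
    group_name="my_group",
    component_type="actor",
    expected_peers={"learner": 1},
    component_name="actor_0",                    # replaces the old COMPONENT_NAME env-var lookup
    redis_address=("localhost", 6379),
    initial_redis_connect_retry_interval=0.1,    # doubled after every failed ping
    max_redis_connect_retries=5,
    initial_peer_discovery_retry_interval=0.1,
    max_peer_discovery_retries=5,
)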
try: signal.signal(signal.SIGINT, self._signal_handler) @@ -226,47 +249,40 @@ def _get_peers_list(self): if not self._peers_info_dict: raise PeersMissError(f"Cannot get {self._name}\'s peers.") - for peer_type in self._peers_info_dict.keys(): - peer_hash_name, peer_number = self._peers_info_dict[peer_type] - retry_number = 0 - expected_peers_name = [] - while retry_number < self._max_retries: - if self._redis_connection.hlen(peer_hash_name) >= peer_number: - expected_peers_name = self._redis_connection.hkeys(peer_hash_name) - expected_peers_name = [peer.decode() for peer in expected_peers_name] - if len(expected_peers_name) > peer_number: - expected_peers_name = expected_peers_name[:peer_number] - self._logger.info(f"{self._name} successfully get all {peer_type}\'s name.") + for peer_type, (peer_hash_name, num_expected) in self._peers_info_dict.items(): + registered_peers, next_retry = [], self._initial_peer_discovery_retry_interval + for _ in range(self._max_peer_discovery_retries): + if self._redis_connection.hlen(peer_hash_name) >= num_expected: + registered_peers = [peer.decode() for peer in self._redis_connection.hkeys(peer_hash_name)] + if len(registered_peers) > num_expected: + del registered_peers[num_expected:] + self._logger.debug(f"{self._name} successfully get all {peer_type}\'s names.") break else: self._logger.warn( - f"{self._name} failed to get {peer_type}\'s name. Retrying in " - f"{self._retry_interval_base_value * (2 ** retry_number)} seconds." + f"{self._name} failed to get {peer_type}\'s name. Retrying in {next_retry} seconds." ) - time.sleep(self._retry_interval_base_value * (2 ** retry_number)) - retry_number += 1 + time.sleep(next_retry) + next_retry *= 2 - if not expected_peers_name: + if not registered_peers: raise InformationUncompletedError( - f"{self._name} failure to get enough number of {peer_type} from redis." + f"{self._name} failed to get the required number of {peer_type}s from redis." ) - self._onboard_peer_dict[peer_type] = {peer_name: None for peer_name in expected_peers_name} + self._onboard_peer_dict[peer_type] = {peer_name: None for peer_name in registered_peers} self._onboard_peers_start_time = time.time() def _build_connection(self): """Grabbing all peers' address from Redis, and connect all peers in driver.""" - for peer_type in self._peers_info_dict.keys(): + for peer_type, info in self._peers_info_dict.items(): name_list = list(self._onboard_peer_dict[peer_type].keys()) try: - peers_socket_value = self._redis_connection.hmget( - self._peers_info_dict[peer_type].hash_table_name, - name_list - ) + peers_socket_value = self._redis_connection.hmget(info.hash_table_name, name_list) for idx, peer_name in enumerate(name_list): self._onboard_peer_dict[peer_type][peer_name] = json.loads(peers_socket_value[idx]) - self._logger.info(f"{self._name} successfully get {peer_name}\'s socket address") + self._logger.debug(f"{self._name} successfully get {peer_name}\'s socket address") except Exception as e: raise InformationUncompletedError(f"{self._name} failed to get {name_list}\'s address. 
Due to {str(e)}") @@ -288,19 +304,27 @@ def component_type(self) -> str: return self._component_type @property - def peers_name(self) -> Dict: + def peers(self) -> Dict: """Dict: The ``Dict`` of all connected peers' names, stored by peer type.""" return { peer_type: list(self._onboard_peer_dict[peer_type].keys()) for peer_type in self._peers_info_dict.keys() } - def receive(self, is_continuous: bool = True, timeout: int = None): - """Receive messages from communication driver. + def receive(self, timeout: int = None): + """Enter an infinite loop of receiving messages from the communication driver. + + Args: + timeout (int): Timeout for each receive attempt. If the first attempt times out, the function returns None. + """ + return self._driver.receive(timeout=timeout) + + def receive_once(self, timeout: int = None): + """Receive a single message from the communication driver. Args: - is_continuous (bool): Continuously receive message or not. Defaults to True. + timeout (int): Timeout for receive attempt. """ - return self._driver.receive(is_continuous, timeout=timeout) + return self._driver.receive_once(timeout=timeout) def receive_by_id(self, targets: List[str], timeout: int = None) -> List[Message]: """Receive target messages from communication driver. @@ -334,7 +358,7 @@ def receive_by_id(self, targets: List[str], timeout: int = None) -> List[Message return received_messages # Wait for incoming messages. - for msg in self._driver.receive(is_continuous=True, timeout=timeout): + for msg in self._driver.receive(timeout=timeout): if not msg: return received_messages @@ -353,17 +377,17 @@ def _scatter( self, tag: Union[str, Enum], session_type: SessionType, - destination_payload_list: list + destination_body_list: list ) -> List[str]: """Scatters a list of data to peers, and return list of session id.""" session_id_list = [] - for destination, payload in destination_payload_list: + for destination, body in destination_body_list: message = SessionMessage( tag=tag, source=self._name, destination=destination, - payload=payload, + body=body, session_type=session_type ) send_result = self.isend(message) @@ -376,7 +400,7 @@ def scatter( self, tag: Union[str, Enum], session_type: SessionType, - destination_payload_list: list, + destination_body_list: list, timeout: int = -1 ) -> List[Message]: """Scatters a list of data to peers, and return replied messages. @@ -384,15 +408,15 @@ def scatter( Args: tag (str|Enum): Message's tag. session_type (Enum): Message's session type. - destination_payload_list ([Tuple(str, object)]): The destination-payload list. + destination_body_list ([Tuple(str, object)]): The destination-body list. The first item of the tuple in list is the message destination, - and the second item of the tuple in list is the message payload. + and the second item of the tuple in list is the message body. Returns: List[Message]: List of replied message. """ return self.receive_by_id( - targets=self._scatter(tag, session_type, destination_payload_list), + targets=self._scatter(tag, session_type, destination_body_list), timeout=timeout ) @@ -400,31 +424,31 @@ def iscatter( self, tag: Union[str, Enum], session_type: SessionType, - destination_payload_list: list + destination_body_list: list ) -> List[str]: """Scatters a list of data to peers, and return list of message id. Args: tag (str|Enum): Message's tag. session_type (Enum): Message's session type. - destination_payload_list ([Tuple(str, object)]): The destination-payload list. 
+ destination_body_list ([Tuple(str, object)]): The destination-body list. The first item of the tuple in list is the message's destination, - and the second item of the tuple in list is the message's payload. + and the second item of the tuple in list is the message's body. Returns: List[str]: List of message's session id. """ - return self._scatter(tag, session_type, destination_payload_list) + return self._scatter(tag, session_type, destination_body_list) def _broadcast( self, component_type: str, tag: Union[str, Enum], session_type: SessionType, - payload=None + body=None ) -> List[str]: """Broadcast message to all peers, and return list of session id.""" - if component_type not in list(self._onboard_peer_dict.keys()): + if component_type not in self._onboard_peer_dict: self._logger.error( f"peer_type: {component_type} cannot be recognized. Please check the input of proxy.broadcast." ) @@ -437,7 +461,7 @@ def _broadcast( tag=tag, source=self._name, destination=component_type, - payload=payload, + body=body, session_type=session_type ) @@ -450,7 +474,7 @@ def broadcast( component_type: str, tag: Union[str, Enum], session_type: SessionType, - payload=None, + body=None, timeout: int = None ) -> List[Message]: """Broadcast message to all peers, and return all replied messages. @@ -459,13 +483,13 @@ def broadcast( component_type (str): Broadcast to all peers in this type. tag (str|Enum): Message's tag. session_type (Enum): Message's session type. - payload (object): The true data. Defaults to None. + body (object): The true data. Defaults to None. Returns: List[Message]: List of replied messages. """ return self.receive_by_id( - targets=self._broadcast(component_type, tag, session_type, payload), + targets=self._broadcast(component_type, tag, session_type, body), timeout=timeout ) @@ -474,7 +498,7 @@ def ibroadcast( component_type: str, tag: Union[str, Enum], session_type: SessionType, - payload=None + body=None ) -> List[str]: """Broadcast message to all subscribers, and return list of message's session id. @@ -482,12 +506,12 @@ def ibroadcast( component_type (str): Broadcast to all peers in this type. tag (str|Enum): Message's tag. session_type (Enum): Message's session type. - payload (object): The true data. Defaults to None. + body (object): The true data. Defaults to None. Returns: List[str]: List of message's session id which related to the replied message. """ - return self._broadcast(component_type, tag, session_type, payload) + return self._broadcast(component_type, tag, session_type, body) def _send(self, message: Message) -> Union[List[str], None]: """Send a message to a remote peer. @@ -509,10 +533,10 @@ def _send(self, message: Message) -> Union[List[str], None]: # Check message cache. 
if ( self._enable_message_cache - and message.destination in list(self._onboard_peer_dict[peer_type].keys()) - and message.destination in list(self._message_cache_for_exited_peers.keys()) + and message.destination in self._onboard_peer_dict[peer_type] + and message.destination in self._message_cache_for_exited_peers ): - self._logger.info(f"Sending pending message to {message.destination}.") + self._logger.debug(f"Sending pending message to {message.destination}.") for pending_message in self._message_cache_for_exited_peers[message.destination]: self._driver.send(pending_message) session_id_list.append(pending_message.session_id) @@ -563,7 +587,7 @@ def reply( self, message: Union[SessionMessage, Message], tag: Union[str, Enum] = None, - payload=None, + body=None, ack_reply: bool = False ) -> List[str]: """Reply a received message. @@ -571,13 +595,13 @@ def reply( Args: message (Message): The message need to reply. tag (str|Enum): New message tag, if None, keeps the original message's tag. Defaults to None. - payload (object): New message payload, if None, keeps the original message's payload. Defaults to None. + body (object): New message body, if None, keeps the original message's body. Defaults to None. ack_reply (bool): If True, it is acknowledge reply. Defaults to False. Returns: List[str]: Message belonged session id. """ - message.reply(tag=tag, payload=payload) + message.reply(tag=tag, body=body) if isinstance(message, SessionMessage): if message.session_type == SessionType.TASK: session_stage = TaskSessionStage.RECEIVE if ack_reply else TaskSessionStage.COMPLETE @@ -592,7 +616,7 @@ def forward( message: Union[SessionMessage, Message], destination: str, tag: Union[str, Enum] = None, - payload=None + body=None ) -> List[str]: """Forward a received message. @@ -600,12 +624,12 @@ def forward( message (Message): The message need to forward. destination (str): The receiver of message. tag (str|Enum): New message tag, if None, keeps the original message's tag. Defaults to None. - payload (object): Message payload, if None, keeps the original message's payload. Defaults to None. + body (object): Message body, if None, keeps the original message's body. Defaults to None. Returns: List[str]: Message belonged session id. """ - message.forward(destination=destination, tag=tag, payload=payload) + message.forward(destination=destination, tag=tag, body=body) return self.isend(message) def _check_peers_update(self): @@ -635,18 +659,18 @@ def _check_peers_update(self): for peer_name in union_peer_name: # Add new peers (new key added on redis). if peer_name not in list(self._onboard_peer_dict[peer_type].keys()): - self._logger.info(f"PEER_REJOIN: New peer {peer_name} join.") + self._logger.debug(f"PEER_REJOIN: New peer {peer_name} join.") self._driver.connect({peer_name: onboard_peers_dict_on_redis[peer_name]}) self._onboard_peer_dict[peer_type][peer_name] = onboard_peers_dict_on_redis[peer_name] # Delete out of date peers (old key deleted on local) elif peer_name not in onboard_peers_dict_on_redis.keys(): - self._logger.info(f"PEER_REJOIN: Peer {peer_name} exited.") + self._logger.debug(f"PEER_REJOIN: Peer {peer_name} exited.") self._driver.disconnect({peer_name: self._onboard_peer_dict[peer_type][peer_name]}) del self._onboard_peer_dict[peer_type][peer_name] else: # Peer's ip/port updated, re-connect (value update on redis). 
if onboard_peers_dict_on_redis[peer_name] != self._onboard_peer_dict[peer_type][peer_name]: - self._logger.info(f"PEER_REJOIN: Peer {peer_name} rejoin.") + self._logger.debug(f"PEER_REJOIN: Peer {peer_name} rejoin.") self._driver.disconnect({peer_name: self._onboard_peer_dict[peer_type][peer_name]}) self._driver.connect({peer_name: onboard_peers_dict_on_redis[peer_name]}) self._onboard_peer_dict[peer_type][peer_name] = onboard_peers_dict_on_redis[peer_name] @@ -715,7 +739,7 @@ def _push_message_to_message_cache(self, message: Message): return self._message_cache_for_exited_peers[peer_name].append(message) - self._logger.info(f"Temporarily save message {message.session_id} to message cache.") + self._logger.debug(f"Temporarily save message {message.session_id} to message cache.") def close(self): self._redis_connection.hdel(self._redis_hash_name, self._name) diff --git a/maro/communication/registry_table.py b/maro/communication/registry_table.py index 5a0197c8f..0c85843ff 100644 --- a/maro/communication/registry_table.py +++ b/maro/communication/registry_table.py @@ -272,7 +272,6 @@ def get(self) -> List[Tuple[callable, List[Message]]]: for event, handler_fn in self._event_handler_dict.items(): message_list = event.get_qualified_message() - if message_list: satisfied_handler_fn.append((handler_fn, message_list)) diff --git a/maro/communication/utils/default_parameters.py b/maro/communication/utils/default_parameters.py index a1e9a70f0..2a7e2c92d 100644 --- a/maro/communication/utils/default_parameters.py +++ b/maro/communication/utils/default_parameters.py @@ -6,11 +6,15 @@ proxy = convert_dottable({ "fault_tolerant": False, "delay_for_slow_joiner": 3, + "peer_discovery": { + "initial_retry_interval": 0.1, + "max_retries": 10 + }, "redis": { "host": "localhost", "port": 6379, - "max_retries": 10, - "base_retry_interval": 0.1 + "initial_retry_interval": 0.1, + "max_retries": 10 }, "peer_rejoin": { "enable": False, diff --git a/maro/data_lib/cim/cim_data_container_helpers.py b/maro/data_lib/cim/cim_data_container_helpers.py index e2b2dcfec..f423bb7ea 100644 --- a/maro/data_lib/cim/cim_data_container_helpers.py +++ b/maro/data_lib/cim/cim_data_container_helpers.py @@ -30,6 +30,9 @@ def __init__(self, config_path: str, max_tick: int, topology: str): self._init_data_container() + self._random_seed: Optional[int] = None + self._re_init_data_cntr_flag: bool = False + def _init_data_container(self, topology_seed: int = None): if not os.path.exists(self._config_path): raise FileNotFoundError @@ -46,12 +49,22 @@ def _init_data_container(self, topology_seed: int = None): # Real Data Mode: read data from input data files, no need for any config.yml. 
self._data_cntr = data_from_files(data_folder=self._config_path) - def reset(self, keep_seed: bool): - """Reset data container internal state""" + def reset(self, keep_seed: bool) -> None: + """Reset data container internal state + """ if not keep_seed: - self._init_data_container(random[ROUTE_INIT_RAND_KEY].randint(0, DATA_CONTAINER_INIT_SEED_LIMIT - 1)) + self._random_seed = random[ROUTE_INIT_RAND_KEY].randint(0, DATA_CONTAINER_INIT_SEED_LIMIT - 1) + self._re_init_data_cntr_flag = True + + if self._re_init_data_cntr_flag: + self._init_data_container(self._random_seed) + self._re_init_data_cntr_flag = False else: - self._data_cntr.reset() + self._data_cntr.reset() # Reset the data container with reproduce-ability + + def set_seed(self, random_seed: int) -> None: + self._random_seed = random_seed + self._re_init_data_cntr_flag = True def __getattr__(self, name): return getattr(self._data_cntr, name) diff --git a/maro/rl/__init__.py b/maro/rl/__init__.py index 40571c71e..9a0454564 100644 --- a/maro/rl/__init__.py +++ b/maro/rl/__init__.py @@ -1,29 +1,2 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. - -from maro.rl.agent import ( - DDPG, DQN, AbsAgent, ActorCritic, ActorCriticConfig, DDPGConfig, DQNConfig, MultiAgentWrapper, PolicyGradient -) -from maro.rl.exploration import ( - AbsExplorer, EpsilonGreedyExplorer, GaussianNoiseExplorer, NoiseExplorer, UniformNoiseExplorer -) -from maro.rl.model import AbsBlock, AbsCoreModel, FullyConnectedBlock, OptimOption, SimpleMultiHeadModel -from maro.rl.scheduling import LinearParameterScheduler, Scheduler, TwoPhaseLinearParameterScheduler -from maro.rl.storage import AbsStore, OverwriteType, SimpleStore -from maro.rl.training import AbsLearner, Actor, ActorProxy, OffPolicyLearner, OnPolicyLearner, Trajectory -from maro.rl.utils import ( - ExperienceCollectionUtils, get_k_step_returns, get_lambda_returns, get_log_prob, get_max, - get_truncated_cumulative_reward, select_by_actions -) - -__all__ = [ - "AbsAgent", "ActorCritic", "ActorCriticConfig", "DDPG", "DDPGConfig", "DQN", "DQNConfig", "MultiAgentWrapper", - "PolicyGradient", - "AbsExplorer", "EpsilonGreedyExplorer", "GaussianNoiseExplorer", "NoiseExplorer", "UniformNoiseExplorer", - "AbsBlock", "AbsCoreModel", "FullyConnectedBlock", "OptimOption", "SimpleMultiHeadModel", - "LinearParameterScheduler", "Scheduler", "TwoPhaseLinearParameterScheduler", - "AbsStore", "OverwriteType", "SimpleStore", - "AbsLearner", "Actor", "ActorProxy", "OffPolicyLearner", "OnPolicyLearner", "Trajectory", - "ExperienceCollectionUtils", "get_k_step_returns", "get_lambda_returns", "get_log_prob", "get_max", - "get_truncated_cumulative_reward", "select_by_actions" -] diff --git a/maro/rl/agent/__init__.py b/maro/rl/agent/__init__.py deleted file mode 100644 index 137f0a66b..000000000 --- a/maro/rl/agent/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
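The reworked seed handling above defers re-initialization to the next reset: set_seed only records the seed and raises a flag, and the following reset rebuilds the data container with it. A hypothetical call sequence (`helper` is a placeholder for the data container wrapper instance):

# Hypothetical sketch of the seed-handling flow introduced above.
helper.set_seed(42)            # record the seed and mark the container for re-initialization
helper.reset(keep_seed=True)   # flag is set, so the container is rebuilt with seed 42
helper.reset(keep_seed=True)   # flag already cleared: plain reproducible reset of the container
helper.reset(keep_seed=False)  # draw a fresh random seed and rebuild the container again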
- -from .abs_agent import AbsAgent -from .ac import ActorCritic, ActorCriticConfig -from .agent_wrapper import MultiAgentWrapper -from .ddpg import DDPG, DDPGConfig -from .dqn import DQN, DQNConfig -from .pg import PolicyGradient - -__all__ = [ - "AbsAgent", - "ActorCritic", "ActorCriticConfig", - "MultiAgentWrapper", - "DDPG", "DDPGConfig", - "DQN", "DQNConfig", - "PolicyGradient" -] diff --git a/maro/rl/agent/abs_agent.py b/maro/rl/agent/abs_agent.py deleted file mode 100644 index 377f88a7e..000000000 --- a/maro/rl/agent/abs_agent.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from abc import ABC, abstractmethod - -import torch - -from maro.rl.model import AbsCoreModel - - -class AbsAgent(ABC): - """Abstract RL agent class. - - It's a sandbox for the RL algorithm. Scenario-specific details will be excluded. - We focus on the abstraction algorithm development here. Environment observation and decision events will - be converted to a uniform format before calling in. The output will be converted to an environment - executable format before return back to the environment. Its key responsibility is optimizing policy based - on interaction with the environment. - - Args: - model (AbsCoreModel): Task model or container of task models required by the algorithm. - config: Settings for the algorithm. - """ - def __init__(self, model: AbsCoreModel, config): - self.model = model - self.config = config - self.device = None - - def to_device(self, device): - self.device = device - self.model = self.model.to(device) - - @abstractmethod - def choose_action(self, state): - """This method uses the underlying model(s) to compute an action from a shaped state. - - Args: - state: A state object shaped by a ``StateShaper`` to conform to the model input format. - - Returns: - The action to be taken given ``state``. It is usually necessary to use an ``ActionShaper`` to convert - this to an environment executable action. - """ - return NotImplementedError - - def set_exploration_params(self, **params): - pass - - @abstractmethod - def learn(self, *args, **kwargs): - """Algorithm-specific training logic. - - The parameters are data to train the underlying model on. Algorithm-specific loss and optimization - should be reflected here. - """ - return NotImplementedError - - def load_model(self, model): - """Load models from memory.""" - self.model.load_state_dict(model) - - def dump_model(self): - """Return the algorithm's trainable models.""" - return self.model.state_dict() - - def load_model_from_file(self, path: str): - """Load trainable models from disk. - - Load trainable models from the specified directory. The model file is always prefixed with the agent's name. - - Args: - path (str): path to the directory where the models are saved. - """ - self.model.load_state_dict(torch.load(path)) - - def dump_model_to_file(self, path: str): - """Dump the algorithm's trainable models to disk. - - Dump trainable models to the specified directory. The model file is always prefixed with the agent's name. - - Args: - path (str): path to the directory where the models are saved. - """ - torch.save(self.model.state_dict(), path) diff --git a/maro/rl/agent/ac.py b/maro/rl/agent/ac.py deleted file mode 100644 index 0c80a6a6e..000000000 --- a/maro/rl/agent/ac.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -from typing import Callable, Tuple - -import numpy as np -import torch -from torch.distributions import Categorical -from torch.nn import MSELoss - -from maro.rl.model import SimpleMultiHeadModel -from maro.rl.utils import get_lambda_returns, get_log_prob -from maro.utils.exception.rl_toolkit_exception import UnrecognizedTask - -from .abs_agent import AbsAgent - - -class ActorCriticConfig: - """Configuration for the Actor-Critic algorithm. - - Args: - reward_discount (float): Reward decay as defined in standard RL terminology. - critic_loss_func (Callable): Loss function for the critic model. - train_iters (int): Number of gradient descent steps per call to ``train``. - actor_loss_coefficient (float): The coefficient for actor loss in the total loss function, e.g., - loss = critic_loss + ``actor_loss_coefficient`` * actor_loss. Defaults to 1.0. - k (int): Number of time steps used in computing returns or return estimates. Defaults to -1, in which case - rewards are accumulated until the end of the trajectory. - lam (float): Lambda coefficient used in computing lambda returns. Defaults to 1.0, in which case the usual - k-step return is computed. - clip_ratio (float): Clip ratio in the PPO algorithm (https://arxiv.org/pdf/1707.06347.pdf). Defaults to None, - in which case the actor loss is calculated using the usual policy gradient theorem. - """ - __slots__ = [ - "reward_discount", "critic_loss_func", "train_iters", "actor_loss_coefficient", "k", "lam", "clip_ratio" - ] - - def __init__( - self, - reward_discount: float, - train_iters: int, - critic_loss_func: Callable = MSELoss(), - actor_loss_coefficient: float = 1.0, - k: int = -1, - lam: float = 1.0, - clip_ratio: float = None - ): - self.reward_discount = reward_discount - self.critic_loss_func = critic_loss_func - self.train_iters = train_iters - self.actor_loss_coefficient = actor_loss_coefficient - self.k = k - self.lam = lam - self.clip_ratio = clip_ratio - - -class ActorCritic(AbsAgent): - """Actor Critic algorithm with separate policy and value models. - - References: - https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch. - https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f - - Args: - model (SimpleMultiHeadModel): Multi-task model that computes action distributions and state values. - It may or may not have a shared bottom stack. - config: Configuration for the AC algorithm. - """ - def __init__(self, model: SimpleMultiHeadModel, config: ActorCriticConfig): - if model.task_names is None or set(model.task_names) != {"actor", "critic"}: - raise UnrecognizedTask(f"Expected model task names 'actor' and 'critic', but got {model.task_names}") - super().__init__(model, config) - - def choose_action(self, state: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Use the actor (policy) model to generate stochastic actions. - - Args: - state: Input to the actor model. - - Returns: - Actions and corresponding log probabilities. 
- """ - state = torch.from_numpy(state).to(self.device) - is_single = len(state.shape) == 1 - if is_single: - state = state.unsqueeze(dim=0) - - action_prob = Categorical(self.model(state, task_name="actor", training=False)) - action = action_prob.sample() - log_p = action_prob.log_prob(action) - action, log_p = action.cpu().numpy(), log_p.cpu().numpy() - return (action[0], log_p[0]) if is_single else (action, log_p) - - def learn( - self, states: np.ndarray, actions: np.ndarray, log_p: np.ndarray, rewards: np.ndarray - ): - states = torch.from_numpy(states).to(self.device) - actions = torch.from_numpy(actions).to(self.device) - log_p = torch.from_numpy(log_p).to(self.device) - rewards = torch.from_numpy(rewards).to(self.device) - - state_values = self.model(states, task_name="critic").detach().squeeze() - return_est = get_lambda_returns( - rewards, state_values, self.config.reward_discount, self.config.lam, k=self.config.k - ) - advantages = return_est - state_values - - for i in range(self.config.train_iters): - # actor loss - log_p_new = get_log_prob(self.model(states, task_name="actor"), actions) - if self.config.clip_ratio is not None: - ratio = torch.exp(log_p_new - log_p) - clip_ratio = torch.clamp(ratio, 1 - self.config.clip_ratio, 1 + self.config.clip_ratio) - actor_loss = -(torch.min(ratio * advantages, clip_ratio * advantages)).mean() - else: - actor_loss = -(log_p_new * advantages).mean() - - # critic_loss - state_values = self.model(states, task_name="critic").squeeze() - critic_loss = self.config.critic_loss_func(state_values, return_est) - loss = critic_loss + self.config.actor_loss_coefficient * actor_loss - self.model.step(loss) diff --git a/maro/rl/agent/agent_wrapper.py b/maro/rl/agent/agent_wrapper.py deleted file mode 100644 index 26f18b028..000000000 --- a/maro/rl/agent/agent_wrapper.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import os -from typing import List, Union - - -class MultiAgentWrapper: - """Multi-agent wrapper class that exposes the same interfaces as a single agent.""" - def __init__(self, agent_dict: dict): - self.agent_dict = agent_dict - - def __getitem__(self, agent_id: str): - return self.agent_dict[agent_id] - - def choose_action(self, state_by_agent: dict): - return {agent_id: self.agent_dict[agent_id].choose_action(state) for agent_id, state in state_by_agent.items()} - - def set_exploration_params(self, params): - # Per-agent exploration parameters - if isinstance(params, dict) and params.keys() <= self.agent_dict.keys(): - for agent_id, params in params.items(): - self.agent_dict[agent_id].set_exploration_params(**params) - # Shared exploration parameters for all agents - else: - for agent in self.agent_dict.values(): - agent.set_exploration_params(**params) - - def load_model(self, model_dict: dict): - """Load models from memory for each agent.""" - for agent_id, model in model_dict.items(): - self.agent_dict[agent_id].load_model(model) - - def dump_model(self, agent_ids: Union[str, List[str]] = None): - """Get agents' underlying models. - - This is usually used in distributed mode where models need to be broadcast to remote roll-out actors. 
- """ - if agent_ids is None: - return {agent_id: agent.dump_model() for agent_id, agent in self.agent_dict.items()} - elif isinstance(agent_ids, str): - return self.agent_dict[agent_ids].dump_model() - else: - return {agent_id: self.agent_dict[agent_id].dump_model() for agent_id in self.agent_dict} - - def load_model_from_file(self, dir_path): - """Load models from disk for each agent.""" - for agent_id, agent in self.agent_dict.items(): - agent.load_model_from_file(os.path.join(dir_path, agent_id)) - - def dump_model_to_file(self, dir_path: str, agent_ids: Union[str, List[str]] = None): - """Dump agents' models to disk. - - Each agent will use its own name to create a separate file under ``dir_path`` for dumping. - """ - os.makedirs(dir_path, exist_ok=True) - if agent_ids is None: - for agent_id, agent in self.agent_dict.items(): - agent.dump_model_to_file(os.path.join(dir_path, agent_id)) - elif isinstance(agent_ids, str): - self.agent_dict[agent_ids].dump_model_to_file(os.path.join(dir_path, agent_ids)) - else: - for agent_id in agent_ids: - self.agent_dict[agent_id].dump_model_to_file(os.path.join(dir_path, agent_id)) diff --git a/maro/rl/agent/ddpg.py b/maro/rl/agent/ddpg.py deleted file mode 100644 index a4be15e85..000000000 --- a/maro/rl/agent/ddpg.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from typing import Callable, Union - -import numpy as np -import torch - -from maro.rl.exploration import NoiseExplorer -from maro.rl.model import SimpleMultiHeadModel -from maro.utils.exception.rl_toolkit_exception import UnrecognizedTask - -from .abs_agent import AbsAgent - - -class DDPGConfig: - """Configuration for the DDPG algorithm. - Args: - reward_discount (float): Reward decay as defined in standard RL terminology. - q_value_loss_func (Callable): Loss function for the Q-value estimator. - target_update_freq (int): Number of training rounds between policy target model updates. - actor_loss_coefficient (float): The coefficient for policy loss in the total loss function, e.g., - loss = q_value_loss + ``policy_loss_coefficient`` * policy_loss. Defaults to 1.0. - tau (float): Soft update coefficient, e.g., target_model = tau * eval_model + (1-tau) * target_model. - Defaults to 1.0. - """ - __slots__ = ["reward_discount", "q_value_loss_func", "target_update_freq", "policy_loss_coefficient", "tau"] - - def __init__( - self, - reward_discount: float, - q_value_loss_func: Callable, - target_update_freq: int, - policy_loss_coefficient: float = 1.0, - tau: float = 1.0, - ): - self.reward_discount = reward_discount - self.q_value_loss_func = q_value_loss_func - self.target_update_freq = target_update_freq - self.policy_loss_coefficient = policy_loss_coefficient - self.tau = tau - - -class DDPG(AbsAgent): - """The Deep Deterministic Policy Gradient (DDPG) algorithm. - - References: - https://arxiv.org/pdf/1509.02971.pdf - https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg - - Args: - model (SimpleMultiHeadModel): DDPG policy and q-value models. - config: Configuration for DDPG algorithm. - explorer (NoiseExplorer): An NoiseExplorer instance for generating exploratory actions. Defaults to None. 
- """ - def __init__(self, model: SimpleMultiHeadModel, config: DDPGConfig, explorer: NoiseExplorer = None): - if model.task_names is None or set(model.task_names) != {"policy", "q_value"}: - raise UnrecognizedTask(f"Expected model task names 'policy' and 'q_value', but got {model.task_names}") - super().__init__(model, config) - self._explorer = explorer - self._target_model = model.copy() if model.trainable else None - self._train_cnt = 0 - - def choose_action(self, state) -> Union[float, np.ndarray]: - state = torch.from_numpy(state).to(self.device) - is_single = len(state.shape) == 1 - if is_single: - state = state.unsqueeze(dim=0) - - action = self.model(state, task_name="policy", training=False).data.cpu().numpy() - action_dim = action.shape[1] - if self._explorer: - action = self._explorer(action) - - if action_dim == 1: - action = action.squeeze(axis=1) - - return action[0] if is_single else action - - def learn(self, states: np.ndarray, actions: np.ndarray, rewards: np.ndarray, next_states: np.ndarray): - states = torch.from_numpy(states).to(self.device) - actual_actions = torch.from_numpy(actions).to(self.device) - rewards = torch.from_numpy(rewards).to(self.device) - next_states = torch.from_numpy(next_states).to(self.device) - if len(actual_actions.shape) == 1: - actual_actions = actual_actions.unsqueeze(dim=1) # (N, 1) - - current_q_values = self.model(torch.cat([states, actual_actions], dim=1), task_name="q_value") - current_q_values = current_q_values.squeeze(dim=1) # (N,) - next_actions = self._target_model(states, task_name="policy", training=False) - next_q_values = self._target_model( - torch.cat([next_states, next_actions], dim=1), task_name="q_value", training=False - ).squeeze(1) # (N,) - target_q_values = (rewards + self.config.reward_discount * next_q_values).detach() # (N,) - q_value_loss = self.config.q_value_loss_func(current_q_values, target_q_values) - actions_from_model = self.model(states, task_name="policy") - policy_loss = -self.model(torch.cat([states, actions_from_model], dim=1), task_name="q_value").mean() - self.model.learn(q_value_loss + self.config.policy_loss_coefficient * policy_loss) - self._train_cnt += 1 - if self._train_cnt % self.config.target_update_freq == 0: - self._target_model.soft_update(self.model, self.config.tau) diff --git a/maro/rl/agent/dqn.py b/maro/rl/agent/dqn.py deleted file mode 100644 index 86b33c129..000000000 --- a/maro/rl/agent/dqn.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from typing import Union - -import numpy as np -import torch - -from maro.rl.model import SimpleMultiHeadModel -from maro.rl.utils import get_max, get_td_errors, select_by_actions -from maro.utils.exception.rl_toolkit_exception import UnrecognizedTask - -from .abs_agent import AbsAgent - - -class DQNConfig: - """Configuration for the DQN algorithm. - - Args: - reward_discount (float): Reward decay as defined in standard RL terminology. - epsilon (float): Exploration rate for epsilon-greedy exploration. Defaults to None. - tau (float): Soft update coefficient, i.e., target_model = tau * eval_model + (1 - tau) * target_model. - double (bool): If True, the next Q values will be computed according to the double DQN algorithm, - i.e., q_next = Q_target(s, argmax(Q_eval(s, a))). Otherwise, q_next = max(Q_target(s, a)). - See https://arxiv.org/pdf/1509.06461.pdf for details. Defaults to False. - advantage_type (str): Advantage mode for the dueling architecture. 
Defaults to None, in which - case it is assumed that the regular Q-value model is used. - loss_cls: Loss function class for evaluating TD errors. Defaults to torch.nn.MSELoss. - target_update_freq (int): Number of training rounds between target model updates. - """ - __slots__ = [ - "reward_discount", "target_update_freq", "epsilon", "tau", "double", "advantage_type", "loss_func" - ] - - def __init__( - self, - reward_discount: float, - target_update_freq: int, - epsilon: float = .0, - tau: float = 0.1, - double: bool = True, - advantage_type: str = None, - loss_cls=torch.nn.MSELoss - ): - self.reward_discount = reward_discount - self.target_update_freq = target_update_freq - self.epsilon = epsilon - self.tau = tau - self.double = double - self.advantage_type = advantage_type - self.loss_func = loss_cls(reduction="none") - - -class DQN(AbsAgent): - """The Deep-Q-Networks algorithm. - - See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details. - - Args: - model (SimpleMultiHeadModel): Q-value model. - config: Configuration for DQN algorithm. - """ - def __init__(self, model: SimpleMultiHeadModel, config: DQNConfig): - if (config.advantage_type is not None and - (model.task_names is None or set(model.task_names) != {"state_value", "advantage"})): - raise UnrecognizedTask( - f"Expected model task names 'state_value' and 'advantage' since dueling DQN is used, " - f"got {model.task_names}" - ) - super().__init__(model, config) - self._training_counter = 0 - self._target_model = model.copy() if model.trainable else None - - def choose_action(self, state: np.ndarray) -> Union[int, np.ndarray]: - state = torch.from_numpy(state) - if self.device: - state = state.to(self.device) - is_single = len(state.shape) == 1 - if is_single: - state = state.unsqueeze(dim=0) - - q_values = self._get_q_values(state, training=False) - num_actions = q_values.shape[1] - greedy_action = q_values.argmax(dim=1).data.cpu() - # No exploration - if self.config.epsilon == .0: - return greedy_action.item() if is_single else greedy_action.numpy() - - if is_single: - return greedy_action if np.random.random() > self.config.epsilon else np.random.choice(num_actions) - - # batch inference - return np.array([ - act if np.random.random() > self.config.epsilon else np.random.choice(num_actions) - for act in greedy_action - ]) - - def learn(self, states: np.ndarray, actions: np.ndarray, rewards: np.ndarray, next_states: np.ndarray): - states = torch.from_numpy(states) - actions = torch.from_numpy(actions) - rewards = torch.from_numpy(rewards) - next_states = torch.from_numpy(next_states) - - if self.device: - states = states.to(self.device) - actions = actions.to(self.device) - rewards = rewards.to(self.device) - next_states = next_states.to(self.device) - - q_all = self._get_q_values(states) - q = select_by_actions(q_all, actions) - next_q_all_target = self._get_q_values(next_states, is_eval=False, training=False) - if self.config.double: - next_q_all_eval = self._get_q_values(next_states, training=False) - next_q = select_by_actions(next_q_all_target, next_q_all_eval.max(dim=1)[1]) # (N,) - else: - next_q, _ = get_max(next_q_all_target) # (N,) - - loss = get_td_errors(q, next_q, rewards, self.config.reward_discount, loss_func=self.config.loss_func) - self.model.step(loss.mean()) - self._training_counter += 1 - if self._training_counter % self.config.target_update_freq == 0: - self._target_model.soft_update(self.model, self.config.tau) - - return loss.detach().numpy() - - def 
set_exploration_params(self, epsilon): - self.config.epsilon = epsilon - - def _get_q_values(self, states: torch.Tensor, is_eval: bool = True, training: bool = True): - output = self.model(states, training=training) if is_eval else self._target_model(states, training=False) - if self.config.advantage_type is None: - return output - else: - state_values = output["state_value"] - advantages = output["advantage"] - # Use mean or max correction to address the identifiability issue - corrections = advantages.mean(1) if self.config.advantage_type == "mean" else advantages.max(1)[0] - return state_values + advantages - corrections.unsqueeze(1) diff --git a/maro/rl/agent/pg.py b/maro/rl/agent/pg.py deleted file mode 100644 index b58acbe09..000000000 --- a/maro/rl/agent/pg.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from typing import Tuple - -import numpy as np -import torch -from torch.distributions import Categorical - -from maro.rl.model import SimpleMultiHeadModel -from maro.rl.utils import get_truncated_cumulative_reward - -from .abs_agent import AbsAgent - - -class PolicyGradient(AbsAgent): - """The vanilla Policy Gradient (VPG) algorithm, a.k.a., REINFORCE. - - Reference: https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch. - - Args: - model (SimpleMultiHeadModel): Model that computes action distributions. - reward_discount (float): Reward decay as defined in standard RL terminology. - """ - def __init__(self, model: SimpleMultiHeadModel, reward_discount: float): - super().__init__(model, reward_discount) - - def choose_action(self, state: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Use the actor (policy) model to generate stochastic actions. - - Args: - state: Input to the actor model. - - Returns: - Actions and corresponding log probabilities. - """ - state = torch.from_numpy(state).to(self.device) - is_single = len(state.shape) == 1 - if is_single: - state = state.unsqueeze(dim=0) - - action_prob = Categorical(self.model(state, training=False)) - action = action_prob.sample() - log_p = action_prob.log_prob(action) - action, log_p = action.cpu().numpy(), log_p.cpu().numpy() - return (action[0], log_p[0]) if is_single else (action, log_p) - - def learn(self, states: np.ndarray, actions: np.ndarray, rewards: np.ndarray): - states = torch.from_numpy(states).to(self.device) - actions = torch.from_numpy(actions).to(self.device) - returns = get_truncated_cumulative_reward(rewards, self.config) - returns = torch.from_numpy(returns).to(self.device) - action_distributions = self.model(states) - action_prob = action_distributions.gather(1, actions.unsqueeze(1)).squeeze() # (N, 1) - loss = -(torch.log(action_prob) * returns).mean() - self.model.step(loss) diff --git a/maro/rl/distributed/__init__.py b/maro/rl/distributed/__init__.py new file mode 100644 index 000000000..cd9276166 --- /dev/null +++ b/maro/rl/distributed/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .abs_proxy import AbsProxy +from .abs_worker import AbsWorker + +__all__ = [ + "AbsProxy", "AbsWorker", +] diff --git a/maro/rl/distributed/abs_proxy.py b/maro/rl/distributed/abs_proxy.py new file mode 100644 index 000000000..31002f0f4 --- /dev/null +++ b/maro/rl/distributed/abs_proxy.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from abc import abstractmethod + +import zmq +from tornado.ioloop import IOLoop +from zmq import Context +from zmq.eventloop.zmqstream import ZMQStream + +from maro.rl.utils.common import get_own_ip_address + + +class AbsProxy(object): + """Abstract proxy class that serves as an intermediary between task producers and task consumers. + + The proxy receives compute tasks from multiple clients, forwards them to a set of back-end workers for + processing and returns the results to the clients. + + Args: + frontend_port (int): Network port for communicating with clients (task producers). + backend_port (int): Network port for communicating with back-end workers (task consumers). + """ + + def __init__(self, frontend_port: int, backend_port: int) -> None: + super(AbsProxy, self).__init__() + + # ZMQ sockets and streams + self._context = Context.instance() + self._req_socket = self._context.socket(zmq.ROUTER) + self._ip_address = get_own_ip_address() + self._req_socket.bind(f"tcp://{self._ip_address}:{frontend_port}") + self._req_endpoint = ZMQStream(self._req_socket) + self._dispatch_socket = self._context.socket(zmq.ROUTER) + self._dispatch_socket.bind(f"tcp://{self._ip_address}:{backend_port}") + self._dispatch_endpoint = ZMQStream(self._dispatch_socket) + self._event_loop = IOLoop.current() + + # register handlers + self._dispatch_endpoint.on_recv(self._send_result_to_requester) + + @abstractmethod + def _route_request_to_compute_node(self, msg: list) -> None: + """Dispatch the task to one or more workers for processing. + + The dispatching strategy should be implemented here. + + Args: + msg (list): Multi-part message containing task specifications and parameters. + """ + raise NotImplementedError + + @abstractmethod + def _send_result_to_requester(self, msg: list) -> None: + """Return a task result to the client that requested it. + + The result aggregation logic, if applicable, should be implemented here. + + Args: + msg (list): Multi-part message containing a task result. + """ + raise NotImplementedError + + def start(self) -> None: + """Start a Tornado event loop. + + Calling this enters the proxy into an event loop where it starts doing its job. + """ + self._event_loop.start() + + def stop(self) -> None: + """Stop the currently running event loop. + """ + self._event_loop.stop() diff --git a/maro/rl/distributed/abs_worker.py b/maro/rl/distributed/abs_worker.py new file mode 100644 index 000000000..9baafb07a --- /dev/null +++ b/maro/rl/distributed/abs_worker.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import abstractmethod + +import zmq +from tornado.ioloop import IOLoop +from zmq import Context +from zmq.eventloop.zmqstream import ZMQStream + +from maro.rl.utils.common import get_ip_address_by_hostname, string_to_bytes +from maro.utils import DummyLogger, LoggerV2 + + +class AbsWorker(object): + """Abstract worker class to process a task in distributed fashion. + + Args: + idx (int): Integer identifier for the worker. It is used to generate an internal ID, "worker.{idx}", + so that the task producer can keep track of its connection status. + producer_host (str): IP address of the task producer host to connect to. + producer_port (int): Port of the task producer host to connect to. + logger (Logger, default=None): The logger of the workflow. 
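# Illustrative sketch, not part of this diff: a minimal AbsProxy subclass that
# dispatches each client request to back-end workers in round-robin order.
# Assumptions: the subclass registers the frontend handler itself (the base class
# above only wires the backend stream), workers announce themselves via the b"READY"
# frame sent by AbsWorker, and ROUTER framing is [sender_id, ...payload].
from collections import deque

from maro.rl.distributed import AbsProxy


class RoundRobinProxy(AbsProxy):
    def __init__(self, frontend_port: int, backend_port: int) -> None:
        super().__init__(frontend_port, backend_port)
        self._workers = deque()  # worker identities learned from READY handshakes
        self._req_endpoint.on_recv(self._route_request_to_compute_node)

    def _route_request_to_compute_node(self, msg: list) -> None:
        # msg = [client_id, ...payload]; prepend the next worker's id so the backend
        # ROUTER delivers it. Requests arriving before any worker registers are dropped.
        if self._workers:
            worker_id = self._workers[0]
            self._workers.rotate(-1)
            self._dispatch_endpoint.send_multipart([worker_id] + msg)

    def _send_result_to_requester(self, msg: list) -> None:
        # msg = [worker_id, client_id, ...result]; remember the worker, strip its id
        # and route the rest back to the requesting client.
        worker_id, rest = msg[0], msg[1:]
        if worker_id not in self._workers:
            self._workers.append(worker_id)
        if rest and rest[-1] != b"READY":
            self._req_endpoint.send_multipart(rest)


# Hypothetical ports, for illustration only.
# RoundRobinProxy(frontend_port=10000, backend_port=10001).start()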
+ """ + + def __init__( + self, + idx: int, + producer_host: str, + producer_port: int, + logger: LoggerV2 = None, + ) -> None: + super(AbsWorker, self).__init__() + + self._id = f"worker.{idx}" + self._logger = logger if logger else DummyLogger() + + # ZMQ sockets and streams + self._context = Context.instance() + self._socket = self._context.socket(zmq.DEALER) + self._socket.identity = string_to_bytes(self._id) + + self._producer_ip = get_ip_address_by_hostname(producer_host) + self._producer_address = f"tcp://{self._producer_ip}:{producer_port}" + self._socket.connect(self._producer_address) + self._logger.info(f"Connected to producer at {self._producer_address}") + + self._stream = ZMQStream(self._socket) + self._stream.send(b"READY") + + self._event_loop = IOLoop.current() + + # register handlers + self._stream.on_recv(self._compute) + + @abstractmethod + def _compute(self, msg: list) -> None: + """The task processing logic should be implemented here. + + Args: + msg (list): Multi-part message containing task specifications and parameters. + """ + raise NotImplementedError + + def start(self) -> None: + """Start a Tornado event loop. + + Calling this enters the worker into an event loop where it starts doing its job. + """ + self._event_loop.start() + + def stop(self) -> None: + """Stop the currently running event loop. + """ + self._event_loop.stop() diff --git a/maro/rl/exploration/__init__.py b/maro/rl/exploration/__init__.py index e4a94ff21..7e60b56c5 100644 --- a/maro/rl/exploration/__init__.py +++ b/maro/rl/exploration/__init__.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .abs_explorer import AbsExplorer -from .epsilon_greedy_explorer import EpsilonGreedyExplorer -from .noise_explorer import GaussianNoiseExplorer, NoiseExplorer, UniformNoiseExplorer +from .scheduling import AbsExplorationScheduler, LinearExplorationScheduler, MultiLinearExplorationScheduler +from .strategies import epsilon_greedy, gaussian_noise, uniform_noise -__all__ = ["AbsExplorer", "EpsilonGreedyExplorer", "GaussianNoiseExplorer", "NoiseExplorer", "UniformNoiseExplorer"] +__all__ = [ + "AbsExplorationScheduler", "LinearExplorationScheduler", "MultiLinearExplorationScheduler", + "epsilon_greedy", "gaussian_noise", "uniform_noise", +] diff --git a/maro/rl/exploration/abs_explorer.py b/maro/rl/exploration/abs_explorer.py deleted file mode 100644 index 40558b263..000000000 --- a/maro/rl/exploration/abs_explorer.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from abc import ABC, abstractmethod - - -class AbsExplorer(ABC): - """Abstract explorer class for generating exploration rates. - - """ - def __init__(self): - pass - - @abstractmethod - def set_parameters(self, **exploration_params): - return NotImplementedError - - @abstractmethod - def __call__(self, action): - return NotImplementedError diff --git a/maro/rl/exploration/epsilon_greedy_explorer.py b/maro/rl/exploration/epsilon_greedy_explorer.py deleted file mode 100644 index 5c9463140..000000000 --- a/maro/rl/exploration/epsilon_greedy_explorer.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from typing import Union - -import numpy as np - -from .abs_explorer import AbsExplorer - - -class EpsilonGreedyExplorer(AbsExplorer): - """Epsilon greedy explorer for discrete action spaces. - - Args: - num_actions (int): Number of all possible actions. 
- """ - def __init__(self, num_actions: int, epsilon: float = .0): - super().__init__() - self._num_actions = num_actions - self._epsilon = epsilon - - def __call__(self, action_index: Union[int, np.ndarray]): - if isinstance(action_index, np.ndarray): - return [self._get_exploration_action(act) for act in action_index] - else: - return self._get_exploration_action(action_index) - - def set_parameters(self, *, epsilon: float): - self._epsilon = epsilon - - def _get_exploration_action(self, action_index): - assert (action_index < self._num_actions), f"Invalid action: {action_index}" - return action_index if np.random.random() > self._epsilon else np.random.choice(self._num_actions) diff --git a/maro/rl/exploration/noise_explorer.py b/maro/rl/exploration/noise_explorer.py deleted file mode 100644 index 999994cb4..000000000 --- a/maro/rl/exploration/noise_explorer.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from abc import abstractmethod -from typing import Union - -import numpy as np - -from .abs_explorer import AbsExplorer - - -class NoiseExplorer(AbsExplorer): - """Explorer that adds a random noise to a model-generated action.""" - def __init__( - self, - min_action: Union[float, list, np.ndarray] = None, - max_action: Union[float, list, np.ndarray] = None - ): - if isinstance(min_action, (list, np.ndarray)) and isinstance(max_action, (list, np.ndarray)): - assert len(min_action) == len(max_action), "min_action and max_action should have the same dimension." - super().__init__() - self._min_action = min_action - self._max_action = max_action - - @abstractmethod - def set_parameters(self, **parameters): - raise NotImplementedError - - @abstractmethod - def __call__(self, action) -> np.ndarray: - raise NotImplementedError - - -class UniformNoiseExplorer(NoiseExplorer): - """Explorer that adds a random noise to a model-generated action sampled from a uniform distribution.""" - def __init__( - self, - min_action: Union[float, list, np.ndarray] = None, - max_action: Union[float, list, np.ndarray] = None, - noise_lower_bound: Union[float, list, np.ndarray] = .0, - noise_upper_bound: Union[float, list, np.ndarray] = .0 - ): - if isinstance(noise_upper_bound, (list, np.ndarray)) and isinstance(noise_upper_bound, (list, np.ndarray)): - assert len(noise_lower_bound) == len(noise_upper_bound), \ - "noise_lower_bound and noise_upper_bound should have the same dimension." 
- super().__init__(min_action, max_action) - self._noise_lower_bound = noise_lower_bound - self._noise_upper_bound = noise_upper_bound - - def set_parameters(self, *, noise_lower_bound, noise_upper_bound): - self._noise_lower_bound = noise_lower_bound - self._noise_upper_bound = noise_upper_bound - - def __call__(self, action: np.ndarray) -> np.ndarray: - return np.array([self._get_exploration_action(act) for act in action]) - - def _get_exploration_action(self, action): - action += np.random.uniform(self._noise_lower_bound, self._noise_upper_bound) - if self._min_action is not None or self._max_action is not None: - return np.clip(action, self._min_action, self._max_action) - else: - return action - - -class GaussianNoiseExplorer(NoiseExplorer): - """Explorer that adds a random noise to a model-generated action sampled from a Gaussian distribution.""" - def __init__( - self, - min_action: Union[float, list, np.ndarray] = None, - max_action: Union[float, list, np.ndarray] = None, - noise_mean: Union[float, list, np.ndarray] = .0, - noise_stddev: Union[float, list, np.ndarray] = .0, - is_relative: bool = False - ): - if isinstance(noise_mean, (list, np.ndarray)) and isinstance(noise_stddev, (list, np.ndarray)): - assert len(noise_mean) == len(noise_stddev), "noise_mean and noise_stddev should have the same dimension." - if is_relative and noise_mean != .0: - raise ValueError("Standard deviation cannot be relative if noise mean is non-zero.") - super().__init__(min_action, max_action) - self._noise_mean = noise_mean - self._noise_stddev = noise_stddev - self._is_relative = is_relative - - def set_parameters(self, *, noise_stddev, noise_mean=.0): - self._noise_stddev = noise_stddev - self._noise_mean = noise_mean - - def __call__(self, action: np.ndarray) -> np.ndarray: - return np.array([self._get_exploration_action(act) for act in action]) - - def _get_exploration_action(self, action): - noise = np.random.normal(loc=self._noise_mean, scale=self._noise_stddev) - action += (noise * action) if self._is_relative else noise - if self._min_action is not None or self._max_action is not None: - return np.clip(action, self._min_action, self._max_action) - else: - return action diff --git a/maro/rl/exploration/scheduling.py b/maro/rl/exploration/scheduling.py new file mode 100644 index 000000000..ce4dd930f --- /dev/null +++ b/maro/rl/exploration/scheduling.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABC, abstractmethod +from typing import List, Tuple + + +class AbsExplorationScheduler(ABC): + """Abstract exploration scheduler. + + Args: + exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the + scheduler is applied. + param_name (str): Name of the exploration parameter to which the scheduler is applied. + initial_value (float, default=None): Initial value for the exploration parameter. If None, the value used + when instantiating the policy will be used as the initial value. 
+ """ + + def __init__(self, exploration_params: dict, param_name: str, initial_value: float = None) -> None: + super().__init__() + self._exploration_params = exploration_params + self.param_name = param_name + if initial_value is not None: + self._exploration_params[self.param_name] = initial_value + + def get_value(self) -> float: + return self._exploration_params[self.param_name] + + @abstractmethod + def step(self) -> None: + raise NotImplementedError + + +class LinearExplorationScheduler(AbsExplorationScheduler): + """Linear exploration parameter schedule. + + Args: + exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the + scheduler is applied. + param_name (str): Name of the exploration parameter to which the scheduler is applied. + last_ep (int): Last episode. + final_value (float): The value of the exploration parameter corresponding to ``last_ep``. + start_ep (int, default=1): starting episode. + initial_value (float, default=None): Initial value for the exploration parameter. If None, the value used + when instantiating the policy will be used as the initial value. + """ + + def __init__( + self, + exploration_params: dict, + param_name: str, + *, + last_ep: int, + final_value: float, + start_ep: int = 1, + initial_value: float = None, + ) -> None: + super().__init__(exploration_params, param_name, initial_value=initial_value) + self.final_value = final_value + if last_ep > 1: + self.delta = (self.final_value - self._exploration_params[self.param_name]) / (last_ep - start_ep) + else: + self.delta = 0 + + def step(self) -> None: + if self._exploration_params[self.param_name] == self.final_value: + return + + self._exploration_params[self.param_name] += self.delta + + +class MultiLinearExplorationScheduler(AbsExplorationScheduler): + """Exploration parameter schedule that consists of multiple linear phases. + + Args: + exploration_params (dict): The exploration params attribute from some ``RLPolicy`` instance to which the + scheduler is applied. + param_name (str): Name of the exploration parameter to which the scheduler is applied. + splits (List[Tuple[int, float]]): List of points that separate adjacent linear phases. Each + point is a (episode, parameter_value) tuple that indicates the end of one linear phase and + the start of another. These points do not have to be given in any particular order. There + cannot be two points with the same first element (episode), or a ``ValueError`` will be raised. + last_ep (int): Last episode. + final_value (float): The value of the exploration parameter corresponding to ``last_ep``. + start_ep (int, default=1): starting episode. + initial_value (float, default=None): Initial value for the exploration parameter. If None, the value from + the original dictionary the policy is instantiated with will be used as the initial value. 
+ """ + + def __init__( + self, + exploration_params: dict, + param_name: str, + *, + splits: List[Tuple[int, float]], + last_ep: int, + final_value: float, + start_ep: int = 1, + initial_value: float = None, + ) -> None: + # validate splits + splits = [(start_ep, initial_value)] + splits + [(last_ep, final_value)] + splits.sort() + for (ep, _), (ep2, _) in zip(splits, splits[1:]): + if ep == ep2: + raise ValueError("The zeroth element of split points must be unique") + + super().__init__(exploration_params, param_name, initial_value=initial_value) + self.final_value = final_value + self._splits = splits + self._ep = start_ep + self._split_index = 1 + self._delta = (self._splits[1][1] - self._exploration_params[self.param_name]) / (self._splits[1][0] - start_ep) + + def step(self) -> None: + if self._split_index == len(self._splits): + return + + self._exploration_params[self.param_name] += self._delta + self._ep += 1 + if self._ep == self._splits[self._split_index][0]: + self._split_index += 1 + if self._split_index < len(self._splits): + self._delta = ( + (self._splits[self._split_index][1] - self._splits[self._split_index - 1][1]) / + (self._splits[self._split_index][0] - self._splits[self._split_index - 1][0]) + ) diff --git a/maro/rl/exploration/strategies.py b/maro/rl/exploration/strategies.py new file mode 100644 index 000000000..c85340c78 --- /dev/null +++ b/maro/rl/exploration/strategies.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Union + +import numpy as np + + +def epsilon_greedy( + state: np.ndarray, + action: np.ndarray, + num_actions: int, + *, + epsilon: float, +) -> np.ndarray: + """Epsilon-greedy exploration. + + Args: + state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the vanilla + eps-greedy exploration and is put here to conform to the function signature required for the exploration + strategy parameter for ``DQN``. + action (np.ndarray): Action(s) chosen greedily by the policy. + num_actions (int): Number of possible actions. + epsilon (float): The probability that a random action will be selected. + + Returns: + Exploratory actions. + """ + return np.array([act if np.random.random() > epsilon else np.random.randint(num_actions) for act in action]) + + +def uniform_noise( + state: np.ndarray, + action: np.ndarray, + min_action: Union[float, list, np.ndarray] = None, + max_action: Union[float, list, np.ndarray] = None, + *, + low: Union[float, list, np.ndarray], + high: Union[float, list, np.ndarray], +) -> Union[float, np.ndarray]: + """Apply a uniform noise to a continuous multidimensional action. + + Args: + state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the gaussian noise + exploration scheme and is put here to conform to the function signature for the exploration in continuous + action spaces. + action (np.ndarray): Action(s) chosen greedily by the policy. + min_action (Union[float, list, np.ndarray], default=None): Lower bound for the multidimensional action space. + max_action (Union[float, list, np.ndarray], default=None): Upper bound for the multidimensional action space. + low (Union[float, list, np.ndarray]): Lower bound for the noise range. + high (Union[float, list, np.ndarray]): Upper bound for the noise range. + + Returns: + Exploration actions with added noise. 
+ """ + if min_action is None and max_action is None: + return action + np.random.uniform(low, high, size=action.shape) + else: + return np.clip(action + np.random.uniform(low, high, size=action.shape), min_action, max_action) + + +def gaussian_noise( + state: np.ndarray, + action: np.ndarray, + min_action: Union[float, list, np.ndarray] = None, + max_action: Union[float, list, np.ndarray] = None, + *, + mean: Union[float, list, np.ndarray] = 0.0, + stddev: Union[float, list, np.ndarray] = 1.0, + relative: bool = False, +) -> Union[float, np.ndarray]: + """Apply a gaussian noise to a continuous multidimensional action. + + Args: + state (np.ndarray): State(s) based on which ``action`` is chosen. This is not used by the gaussian noise + exploration scheme and is put here to conform to the function signature for the exploration in continuous + action spaces. + action (np.ndarray): Action(s) chosen greedily by the policy. + min_action (Union[float, list, np.ndarray], default=None): Lower bound for the multidimensional action space. + max_action (Union[float, list, np.ndarray], default=None): Upper bound for the multidimensional action space. + mean (Union[float, list, np.ndarray], default=0.0): Gaussian noise mean. + stddev (Union[float, list, np.ndarray], default=1.0): Standard deviation for the Gaussian noise. + relative (bool, default=False): If True, the generated noise is treated as a relative measure and will + be multiplied by the action itself before being added to the action. + + Returns: + Exploration actions with added noise (a numpy ndarray). + """ + noise = np.random.normal(loc=mean, scale=stddev, size=action.shape) + if min_action is None and max_action is None: + return action + ((noise * action) if relative else noise) + else: + return np.clip(action + ((noise * action) if relative else noise), min_action, max_action) diff --git a/maro/rl/model/__init__.py b/maro/rl/model/__init__.py index 00c54829d..f07874e34 100644 --- a/maro/rl/model/__init__.py +++ b/maro/rl/model/__init__.py @@ -1,12 +1,25 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .abs_block import AbsBlock -from .fc_block import FullyConnectedBlock -from .learning_model import AbsCoreModel, OptimOption, SimpleMultiHeadModel +from .abs_net import AbsNet +from .algorithm_nets.ac_based import ContinuousACBasedNet, DiscreteACBasedNet +from .algorithm_nets.ddpg import ContinuousDDPGNet +from .algorithm_nets.sac import ContinuousSACNet +from .fc_block import FullyConnected +from .multi_q_net import MultiQNet +from .policy_net import ContinuousPolicyNet, DiscretePolicyNet, PolicyNet +from .q_net import ContinuousQNet, DiscreteQNet, QNet +from .v_net import VNet __all__ = [ - "AbsBlock", - "FullyConnectedBlock", - "AbsCoreModel", "OptimOption", "SimpleMultiHeadModel" + "AbsNet", + "FullyConnected", + "MultiQNet", + "ContinuousPolicyNet", "DiscretePolicyNet", "PolicyNet", + "ContinuousQNet", "DiscreteQNet", "QNet", + "VNet", + + "ContinuousACBasedNet", "DiscreteACBasedNet", + "ContinuousDDPGNet", + "ContinuousSACNet", ] diff --git a/maro/rl/model/abs_block.py b/maro/rl/model/abs_block.py deleted file mode 100644 index 2f5a1e850..000000000 --- a/maro/rl/model/abs_block.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -import torch.nn as nn - - -class AbsBlock(nn.Module): - @property - def input_dim(self): - raise NotImplementedError - - @property - def output_dim(self): - raise NotImplementedError diff --git a/maro/rl/model/abs_net.py b/maro/rl/model/abs_net.py new file mode 100644 index 000000000..cb61714dd --- /dev/null +++ b/maro/rl/model/abs_net.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from abc import ABCMeta +from typing import Any, Dict, Optional + +import torch.nn +from torch.optim import Optimizer + + +class AbsNet(torch.nn.Module, metaclass=ABCMeta): + """Base class for all Torch net classes. `AbsNet` defines a set of methods that will be called by upper-level + logic. All classes that inherit `AbsNet` should implement these methods. + """ + + def __init__(self) -> None: + super(AbsNet, self).__init__() + + self._optim: Optional[Optimizer] = None + + def step(self, loss: torch.Tensor) -> None: + """Run a training step to update the net's parameters according to the given loss. + + Args: + loss (torch.tensor): Loss used to update the model. + """ + self._optim.zero_grad() + loss.backward() + self._optim.step() + + def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: + """Get the gradients with respect to all parameters according to the given loss. + + Args: + loss (torch.tensor): Loss used to compute gradients. + + Returns: + Gradients (Dict[str, torch.Tensor]): A dict that contains gradients for all parameters. + """ + self._optim.zero_grad() + loss.backward() + return {name: param.grad for name, param in self.named_parameters()} + + def apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None: + """Apply gradients to the net to update all parameters. + + Args: + grad (Dict[str, torch.Tensor]): A dict that contains gradients for all parameters. + """ + for name, param in self.named_parameters(): + param.grad = grad[name] + self._optim.step() + + def _forward_unimplemented(self, *input: Any) -> None: + pass + + def get_state(self) -> dict: + """Get the net's state. + + Returns: + state (dict): A object that contains the net's state. + """ + return { + "network": self.state_dict(), + "optim": self._optim.state_dict(), + } + + def set_state(self, net_state: dict) -> None: + """Set the net's state. + + Args: + net_state (dict): A dict that contains the net's state. + """ + self.load_state_dict(net_state["network"]) + self._optim.load_state_dict(net_state["optim"]) + + def soft_update(self, other_model: AbsNet, tau: float) -> None: + """Soft update the net's parameters according to another net, i.e., + self.param = self.param * (1.0 - tau) + other_model.param * tau + + Args: + other_model (AbsNet): The source net. Must has same type with the current net. + tau (float): Soft update coefficient. + """ + assert self.__class__ == other_model.__class__, \ + f"Soft update can only be done between same classes. Current model type: {self.__class__}, " \ + f"other model type: {other_model.__class__}" + + for params, other_params in zip(self.parameters(), other_model.parameters()): + params.data = (1 - tau) * params.data + tau * other_params.data + + def freeze(self) -> None: + """(Partially) freeze the current model. The users should write their own strategy to determine which + parameters to freeze. Freeze all parameters is capable in most cases. You could overwrite this method + when necessary. 
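# Illustrative sketch, not part of this diff: a minimal concrete AbsNet. The base
# class above assumes the subclass creates its own optimizer and stores it in
# ``self._optim``; ``step()``, ``get_state()`` and ``soft_update()`` then work as-is.
import torch
from torch.optim import Adam

from maro.rl.model import AbsNet


class TinyNet(AbsNet):
    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self._fc = torch.nn.Linear(input_dim, output_dim)
        self._optim = Adam(self.parameters(), lr=1e-3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._fc(x)


net, target_net = TinyNet(8, 4), TinyNet(8, 4)
loss = net(torch.randn(16, 8)).pow(2).mean()
net.step(loss)                        # zero_grad -> backward -> optimizer step
target_net.soft_update(net, tau=0.1)  # target = 0.9 * target + 0.1 * net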
+ """ + self.freeze_all_parameters() + + def unfreeze(self) -> None: + """(Partially) unfreeze the current model. The users should write their own strategy to determine which + parameters to freeze. Unfreeze all parameters is capable in most cases. You could overwrite this method + when necessary. + """ + self.unfreeze_all_parameters() + + def freeze_all_parameters(self) -> None: + """Freeze all parameters. + """ + for p in self.parameters(): + p.requires_grad = False + + def unfreeze_all_parameters(self) -> None: + """Unfreeze all parameters. + """ + for p in self.parameters(): + p.requires_grad = True diff --git a/maro/rl/model/algorithm_nets/__init__.py b/maro/rl/model/algorithm_nets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/maro/rl/model/algorithm_nets/ac_based.py b/maro/rl/model/algorithm_nets/ac_based.py new file mode 100644 index 000000000..a05a6cdb1 --- /dev/null +++ b/maro/rl/model/algorithm_nets/ac_based.py @@ -0,0 +1,54 @@ +from abc import ABCMeta +from typing import Tuple + +import torch + +from maro.rl.model.policy_net import ContinuousPolicyNet, DiscretePolicyNet + + +class DiscreteACBasedNet(DiscretePolicyNet, metaclass=ABCMeta): + """Policy net for policies that are trained by Actor-Critic or PPO algorithm and with discrete actions. + + The following methods should be implemented: + - _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + + Overwrite one or multiple of following methods when necessary. + - freeze(self) -> None: + - unfreeze(self) -> None: + - step(self, loss: torch.Tensor) -> None: + - get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: + - apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None: + - get_state(self) -> dict: + - set_state(self, net_state: dict) -> None: + """ + pass + + +class ContinuousACBasedNet(ContinuousPolicyNet, metaclass=ABCMeta): + """Policy net for policies that are trained by Actor-Critic or PPO algorithm and with continuous actions. + + The following methods should be implemented: + - _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + - _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + + Overwrite one or multiple of following methods when necessary. 
+ - freeze(self) -> None: + - unfreeze(self) -> None: + - step(self, loss: torch.Tensor) -> None: + - get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: + - apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None: + - get_state(self) -> dict: + - set_state(self, net_state: dict) -> None: + """ + + def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + actions, _ = self._get_actions_with_logps_impl(states, exploring) + return actions + + def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + # Not used in Actor-Critic or PPO + pass + + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + # Not used in Actor-Critic or PPO + pass diff --git a/maro/rl/model/algorithm_nets/ddpg.py b/maro/rl/model/algorithm_nets/ddpg.py new file mode 100644 index 000000000..5e50abe40 --- /dev/null +++ b/maro/rl/model/algorithm_nets/ddpg.py @@ -0,0 +1,39 @@ +from abc import ABCMeta +from typing import Tuple + +import torch + +from maro.rl.model.policy_net import ContinuousPolicyNet + + +class ContinuousDDPGNet(ContinuousPolicyNet, metaclass=ABCMeta): + """Policy net for policies that are trained by DDPG. + + The following methods should be implemented: + - _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + + Overwrite one or multiple of following methods when necessary. + - freeze(self) -> None: + - unfreeze(self) -> None: + - step(self, loss: torch.Tensor) -> None: + - get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: + - apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None: + - get_state(self) -> dict: + - set_state(self, net_state: dict) -> None: + """ + + def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + # Not used in DDPG + pass + + def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + # Not used in DDPG + pass + + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + # Not used in DDPG + pass + + def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + # Not used in DDPG + pass diff --git a/maro/rl/model/algorithm_nets/sac.py b/maro/rl/model/algorithm_nets/sac.py new file mode 100644 index 000000000..fca970989 --- /dev/null +++ b/maro/rl/model/algorithm_nets/sac.py @@ -0,0 +1,39 @@ +from abc import ABCMeta +from typing import Tuple + +import torch + +from maro.rl.model.policy_net import ContinuousPolicyNet + + +class ContinuousSACNet(ContinuousPolicyNet, metaclass=ABCMeta): + """Policy net for policies that are trained by SAC. + + The following methods should be implemented: + - _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + + Overwrite one or multiple of following methods when necessary. 
+ - freeze(self) -> None: + - unfreeze(self) -> None: + - step(self, loss: torch.Tensor) -> None: + - get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: + - apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None: + - get_state(self) -> dict: + - set_state(self, net_state: dict) -> None: + """ + + def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + actions, _ = self._get_actions_with_logps_impl(states, exploring) + return actions + + def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + # Not used in SAC + pass + + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + # Not used in SAC + pass + + def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + # Not used in SAC + pass diff --git a/maro/rl/model/fc_block.py b/maro/rl/model/fc_block.py index 829b7d581..65c803284 100644 --- a/maro/rl/model/fc_block.py +++ b/maro/rl/model/fc_block.py @@ -2,54 +2,59 @@ # Licensed under the MIT license. from collections import OrderedDict +from typing import Any, List, Optional, Type import torch import torch.nn as nn -from .abs_block import AbsBlock - -class FullyConnectedBlock(AbsBlock): +class FullyConnected(nn.Module): """Fully connected network with optional batch normalization, activation and dropout components. Args: - name (str): Network name. input_dim (int): Network input dimension. output_dim (int): Network output dimension. - hidden_dims ([int]): Dimensions of hidden layers. Its length is the number of hidden layers. - activation: A ``torch.nn`` activation type. If None, there will be no activation. Defaults to LeakyReLU. - head (bool): If true, this block will be the top block of the full model and the top layer of this block - will be the final output layer. Defaults to False. - softmax (bool): If true, the output of the net will be a softmax transformation of the top layer's - output. Defaults to False. - batch_norm (bool): If true, batch normalization will be performed at each layer. - skip_connection (bool): If true, a skip connection will be built between the bottom (input) layer and - top (output) layer. Defaults to False. - dropout_p (float): Dropout probability. Defaults to None, in which case there is no drop-out. - gradient_threshold (float): Gradient clipping threshold. Defaults to None, in which case not gradient clipping + hidden_dims (List[int]): Dimensions of hidden layers. Its length is the number of hidden layers. For example, + `hidden_dims=[128, 256]` refers to two hidden layers with output dim of 128 and 256, respectively. + activation (Optional[Type[torch.nn.Module], default=nn.ReLU): Activation class provided by ``torch.nn`` or a + customized activation class. If None, there will be no activation. + head (bool, default=False): If true, this block will be the top block of the full model and the top layer + of this block will be the final output layer. + softmax (bool, default=False): If true, the output of the net will be a softmax transformation of the top + layer's output. + batch_norm (bool, default=False): If true, batch normalization will be performed at each layer. + skip_connection (bool, default=False): If true, a skip connection will be built between the bottom (input) + layer and top (output) layer. Defaults to False. + dropout_p (float, default=None): Dropout probability. If it is None, there will be no drop-out. 
+ gradient_threshold (float, default=None): Gradient clipping threshold. If it is None, no gradient clipping is performed. + name (str, default=None): Network name. """ + + def _forward_unimplemented(self, *input: Any) -> None: + pass + def __init__( self, input_dim: int, output_dim: int, - hidden_dims: [int], - activation=nn.LeakyReLU, + hidden_dims: List[int], + activation: Optional[Type[torch.nn.Module]] = nn.ReLU, head: bool = False, softmax: bool = False, batch_norm: bool = False, skip_connection: bool = False, dropout_p: float = None, gradient_threshold: float = None, - name: str = None - ): - super().__init__() + name: str = None, + ) -> None: + super(FullyConnected, self).__init__() self._input_dim = input_dim self._hidden_dims = hidden_dims if hidden_dims is not None else [] self._output_dim = output_dim # network features - self._activation = activation + self._activation = activation() if activation else None self._head = head self._softmax = nn.Softmax(dim=1) if softmax else None self._batch_norm = batch_norm @@ -78,26 +83,26 @@ def __init__( self._name = name - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: out = self._net(x) if self._skip_connection: out += x return self._softmax(out) if self._softmax else out @property - def name(self): + def name(self) -> str: return self._name @property - def input_dim(self): + def input_dim(self) -> int: return self._input_dim @property - def output_dim(self): + def output_dim(self) -> int: return self._output_dim - def _build_layer(self, input_dim, output_dim, head: bool = False): - """Build basic layer. + def _build_layer(self, input_dim: int, output_dim: int, head: bool = False) -> torch.nn.Module: + """Build a basic layer. BN -> Linear -> Activation -> Dropout """ @@ -106,7 +111,7 @@ def _build_layer(self, input_dim, output_dim, head: bool = False): components.append(("batch_norm", nn.BatchNorm1d(input_dim))) components.append(("linear", nn.Linear(input_dim, output_dim))) if not head and self._activation is not None: - components.append(("activation", self._activation())) + components.append(("activation", self._activation)) if not head and self._dropout_p: components.append(("dropout", nn.Dropout(p=self._dropout_p))) return nn.Sequential(OrderedDict(components)) diff --git a/maro/rl/model/learning_model.py b/maro/rl/model/learning_model.py deleted file mode 100644 index cb8bfd6ab..000000000 --- a/maro/rl/model/learning_model.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from abc import abstractmethod -from typing import Dict, List, Union - -import torch -import torch.nn as nn - -from maro.utils import clone -from maro.utils.exception.rl_toolkit_exception import MissingOptimizer - - -class OptimOption: - """Model optimization options. - Args: - optim_cls: Subclass of torch.optim.Optimizer. - optim_params (dict): Parameters for the optimizer class. - scheduler_cls: torch lr_scheduler class. Defaults to None. - scheduler_params (dict): Parameters for the scheduler class. Defaults to None. - """ - __slots__ = ["optim_cls", "optim_params", "scheduler_cls", "scheduler_params"] - - def __init__(self, optim_cls, optim_params: dict, scheduler_cls=None, scheduler_params: dict = None): - self.optim_cls = optim_cls - self.optim_params = optim_params - self.scheduler_cls = scheduler_cls - self.scheduler_params = scheduler_params - - -class AbsCoreModel(nn.Module): - """Trainable model that consists of multiple network components. 
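# Illustrative usage sketch, not part of this diff: building a small MLP with the
# renamed FullyConnected block above. With head=True the output layer gets no
# activation or dropout, so the block can serve as a network's final layer.
import torch

from maro.rl.model import FullyConnected

mlp = FullyConnected(
    input_dim=8,
    output_dim=4,
    hidden_dims=[64, 32],
    activation=torch.nn.ReLU,
    head=True,
)
out = mlp(torch.randn(16, 8))   # -> shape (16, 4)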
- - Args: - component (Union[nn.Module, Dict[str, nn.Module]]): Network component(s) comprising the model. - optim_option (Union[OptimOption, Dict[str, OptimOption]]): Optimizer options for the components. - If none, no optimizer will be created for the model which means the model is not trainable. - If it is a OptimOption instance, a single optimizer will be created to jointly optimize all - parameters of the model. If it is a dictionary of OptimOptions, the keys will be matched against - the component names and optimizers created for them. Note that it is possible to freeze certain - components while optimizing others by providing a subset of the keys in ``component``. - Defaults toNone. - """ - def __init__( - self, - component: Union[nn.Module, Dict[str, nn.Module]], - optim_option: Union[OptimOption, Dict[str, OptimOption]] = None - ): - super().__init__() - self._component = component if isinstance(component, nn.Module) else nn.ModuleDict(component) - if optim_option is None: - self.optimizer = None - self.scheduler = None - self.eval() - for param in self.parameters(): - param.requires_grad = False - else: - if isinstance(optim_option, dict): - self.optimizer = {} - self.scheduler = {} - for name, opt in optim_option.items(): - self.optimizer[name] = opt.optim_cls(self._component[name].parameters(), **opt.optim_params) - if opt.scheduler_cls: - self.scheduler[name] = opt.scheduler_cls(self.optimizer[name], **opt.scheduler_params) - else: - self.optimizer = optim_option.optim_cls(self.parameters(), **optim_option.optim_params) - if optim_option.scheduler_cls: - self.scheduler = optim_option.scheduler_cls(self.optimizer, **optim_option.scheduler_params) - - @property - def trainable(self) -> bool: - return self.optimizer is not None - - @abstractmethod - def forward(self, *args, **kwargs): - raise NotImplementedError - - def step(self, loss): - """Use the loss to back-propagate gradients and apply them to the underlying parameters.""" - if self.optimizer is None: - raise MissingOptimizer("No optimizer registered to the model") - if isinstance(self.optimizer, dict): - for optimizer in self.optimizer.values(): - optimizer.zero_grad() - else: - self.optimizer.zero_grad() - - # Obtain gradients through back-propagation - loss.backward() - - # Apply gradients - if isinstance(self.optimizer, dict): - for optimizer in self.optimizer.values(): - optimizer.step() - else: - self.optimizer.step() - - def update_learning_rate(self, component_name: Union[str, List[str]] = None): - if not isinstance(self.scheduler, dict): - self.scheduler.step() - elif isinstance(component_name, str): - if component_name not in self.scheduler: - raise KeyError(f"Component {component_name} does not have a learning rate scheduler") - self.scheduler[component_name].step() - elif isinstance(component_name, list): - for key in component_name: - if key not in self.scheduler: - raise KeyError(f"Component {key} does not have a learning rate scheduler") - self.scheduler[key].step() - else: - for sch in self.scheduler.values(): - sch.step() - - def soft_update(self, other_model: nn.Module, tau: float): - for params, other_params in zip(self.parameters(), other_model.parameters()): - params.data = (1 - tau) * params.data + tau * other_params.data - - def copy(self, with_optimizer: bool = False): - model_copy = clone(self) - if not with_optimizer: - model_copy.optimizer = None - model_copy.scheduler = None - - return model_copy - - -class SimpleMultiHeadModel(AbsCoreModel): - """A compound network structure that consists of 
multiple task heads and an optional shared stack. - - Args: - component (Union[nn.Module, Dict[str, nn.Module]]): Network component(s) comprising the model. - All components must have the same input dimension except the one designated as the shared - component by ``shared_component_name``. - optim_option (Union[OptimOption, Dict[str, OptimOption]]): Optimizer option for - the components. Defaults to None. - shared_component_name (str): Name of the network component to be designated as the shared component at the - bottom of the architecture. Must be None or a key in ``component``. If only a single component - is present, this is ignored. Defaults to None. - """ - def __init__( - self, - component: Union[nn.Module, Dict[str, nn.Module]], - optim_option: Union[OptimOption, Dict[str, OptimOption]] = None, - shared_component_name: str = None - ): - super().__init__(component, optim_option=optim_option) - if isinstance(component, dict): - if shared_component_name is not None: - assert (shared_component_name in component), ( - f"shared_component_name must be one of {list(component.keys())}, got {shared_component_name}" - ) - self._task_names = [name for name in component if name != shared_component_name] - else: - self._task_names = None - self._shared_component_name = shared_component_name - - @property - def task_names(self): - return self._task_names - - def _forward(self, inputs, task_name: str = None): - if not isinstance(self._component, nn.ModuleDict): - return self._component(inputs) - - if self._shared_component_name is not None: - inputs = self._component[self._shared_component_name](inputs) # features - - if task_name is None: - return {name: self._component[name](inputs) for name in self._task_names} - - if isinstance(task_name, list): - return {name: self._component[name](inputs) for name in task_name} - else: - return self._component[task_name](inputs) - - def forward(self, inputs, task_name: Union[str, List[str]] = None, training: bool = True): - """Feedforward computations for the given head(s). - - Args: - inputs: Inputs to the model. - task_name (str): The name of the task for which the network output is required. If the model contains only - one task module, the task_name is ignored and the output of that module will be returned. If the model - contains multiple task modules, then 1) if task_name is None, the output from all task modules will be - returned in the form of a dictionary; 2) if task_name is a list, the outputs from the task modules - specified in the list will be returned in the form of a dictionary; 3) if this is a single string, - the output from the corresponding task module will be returned. - training (bool): If true, all torch submodules will be set to training mode, and auto-differentiation - will be turned on. Defaults to True. - - Returns: - Outputs from the required head(s). - """ - self.train(mode=training) - if training: - return self._forward(inputs, task_name) - - with torch.no_grad(): - return self._forward(inputs, task_name) diff --git a/maro/rl/model/multi_q_net.py b/maro/rl/model/multi_q_net.py new file mode 100644 index 000000000..9fb4e1cb7 --- /dev/null +++ b/maro/rl/model/multi_q_net.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABCMeta, abstractmethod +from typing import List + +import torch + +from maro.rl.utils import match_shape, SHAPE_CHECK_FLAG + +from .abs_net import AbsNet + + +class MultiQNet(AbsNet, metaclass=ABCMeta): + """Abstract net for multi-agent Q functions. 
+ + Args: + state_dim (int): Dimension of states. + action_dims (List[int]): Dimensions of Dimension of multi-agents' actions. Its length equals the + number of agents. + """ + + def __init__(self, state_dim: int, action_dims: List[int]) -> None: + super(MultiQNet, self).__init__() + self._state_dim = state_dim + self._action_dims = action_dims + + @property + def state_dim(self) -> int: + return self._state_dim + + @property + def action_dims(self) -> List[int]: + return self._action_dims + + @property + def agent_num(self) -> int: + return len(self._action_dims) + + def _shape_check(self, states: torch.Tensor, actions: List[torch.Tensor] = None) -> bool: + """Check whether the states and actions have valid shapes. + + Args: + states (torch.Tensor): State tensor. + actions (List[torch.Tensor], default=None): Action tensors. It length must be equal to the number of agents. + If it is None, it means we only check state tensor's shape. + + Returns: + valid_flag (bool): whether the states and actions have valid shapes. + """ + if not SHAPE_CHECK_FLAG: + return True + else: + if states.shape[0] == 0 or not match_shape(states, (None, self.state_dim)): + return False + if actions is not None: + if len(actions) != self.agent_num: + return False + for action, dim in zip(actions, self.action_dims): + if not match_shape(action, (states.shape[0], dim)): + return False + return True + + def q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor: + """Get Q-values according to states and actions. + + Args: + states (torch.Tensor): States. + actions (List[torch.Tensor]): List of actions. + + Returns: + q (torch.Tensor): Q-values with shape [batch_size]. + """ + assert self._shape_check(states, actions) + q = self._get_q_values(states, actions) + assert match_shape(q, (states.shape[0],)), \ + f"Q-value shape check failed. Expecting: {(states.shape[0],)}, actual: {q.shape}." # [B] + return q + + @abstractmethod + def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor: + """Implementation of `q_values`. + """ + raise NotImplementedError diff --git a/maro/rl/model/policy_net.py b/maro/rl/model/policy_net.py new file mode 100644 index 000000000..63b03e10a --- /dev/null +++ b/maro/rl/model/policy_net.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABCMeta, abstractmethod +from typing import Tuple + +import torch.nn +from torch.distributions import Categorical + +from maro.rl.utils import match_shape, SHAPE_CHECK_FLAG + +from .abs_net import AbsNet + + +class PolicyNet(AbsNet, metaclass=ABCMeta): + """Base class for all nets that serve as policy cores. It has the concept of 'state' and 'action'. + + Args: + state_dim (int): Dimension of states. + action_dim (int): Dimension of actions. + """ + + def __init__(self, state_dim: int, action_dim: int) -> None: + super(PolicyNet, self).__init__() + self._state_dim = state_dim + self._action_dim = action_dim + + @property + def state_dim(self) -> int: + return self._state_dim + + @property + def action_dim(self) -> int: + return self._action_dim + + def get_actions(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + actions = self._get_actions_impl(states, exploring) + + assert self._shape_check(states=states, actions=actions), \ + f"Actions shape check failed. 
Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}." + + return actions + + def get_actions_with_probs(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + actions, probs = self._get_actions_with_probs_impl(states, exploring) + + assert self._shape_check(states=states, actions=actions), \ + f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}." + assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0] + + return actions, probs + + def get_actions_with_logps(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + actions, logps = self._get_actions_with_logps_impl(states, exploring) + + assert self._shape_check(states=states, actions=actions), \ + f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}." + assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0] + + return actions, logps + + def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + probs = self._get_states_actions_probs_impl(states, actions) + + assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0] + + return probs + + def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + logps = self._get_states_actions_logps_impl(states, actions) + + assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0] + + return logps + + @abstractmethod + def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None) -> bool: + """Check whether the states and actions have valid shapes. + + Args: + states (torch.Tensor): State tensor. + actions (torch.Tensor, default=None): Action tensor. If it is None, it means we only check state tensor's + shape. + + Returns: + valid_flag (bool): whether the states and actions have valid shapes. 
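# Illustrative sketch, not part of this diff: a concrete MultiQNet (defined earlier
# in maro/rl/model/multi_q_net.py) for two agents. The joint state and both agents'
# actions are concatenated and fed to a FullyConnected value head.
from typing import List

import torch
from torch.optim import Adam

from maro.rl.model import FullyConnected, MultiQNet


class MyMultiQNet(MultiQNet):
    def __init__(self, state_dim: int, action_dims: List[int]) -> None:
        super().__init__(state_dim=state_dim, action_dims=action_dims)
        self._fc = FullyConnected(
            input_dim=state_dim + sum(action_dims),
            output_dim=1,
            hidden_dims=[64],
            head=True,
        )
        self._optim = Adam(self.parameters(), lr=1e-3)

    def _get_q_values(self, states: torch.Tensor, actions: List[torch.Tensor]) -> torch.Tensor:
        return self._fc(torch.cat([states] + actions, dim=1)).squeeze(-1)   # -> [batch_size]


net = MyMultiQNet(state_dim=10, action_dims=[3, 3])
q = net.q_values(torch.rand(8, 10), [torch.rand(8, 3), torch.rand(8, 3)])   # -> shape (8,)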
+ """ + if not SHAPE_CHECK_FLAG: + return True + else: + if states.shape[0] == 0: + return False + if not match_shape(states, (None, self.state_dim)): + return False + + if actions is not None: + if not match_shape(actions, (states.shape[0], self.action_dim)): + return False + return True + + +class DiscretePolicyNet(PolicyNet, metaclass=ABCMeta): + """Policy network for discrete action spaces. + + Args: + state_dim (int): Dimension of states. + action_num (int): Number of actions. + """ + + def __init__(self, state_dim: int, action_num: int) -> None: + super(DiscretePolicyNet, self).__init__(state_dim=state_dim, action_dim=1) + self._action_num = action_num + + @property + def action_num(self) -> int: + return self._action_num + + def get_action_probs(self, states: torch.Tensor) -> torch.Tensor: + """Get the probabilities for all possible actions in the action space. + + Args: + states (torch.Tensor): States. + + Returns: + action_probs (torch.Tensor): Probability matrix with shape [batch_size, action_num]. + """ + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + action_probs = self._get_action_probs_impl(states) + assert match_shape(action_probs, (states.shape[0], self.action_num)), \ + f"Action probabilities shape check failed. Expecting: {(states.shape[0], self.action_num)}, " \ + f"actual: {action_probs.shape}." + return action_probs + + @abstractmethod + def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor: + """Implementation of `get_action_probs`. The core logic of a discrete policy net should be implemented here. + """ + raise NotImplementedError + + def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor: + actions, _ = self._get_actions_with_probs_impl(states, exploring) + return actions + + def _get_actions_with_probs_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + probs = self.get_action_probs(states) + if exploring: + distribution = Categorical(probs) + actions = distribution.sample().unsqueeze(1) + return actions, probs.gather(1, actions).squeeze(-1) + else: + probs, actions = probs.max(dim=1) + return actions.unsqueeze(1), probs + + def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]: + actions, probs = self._get_actions_with_probs_impl(states, exploring) + return actions, torch.log(probs) + + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + probs = self.get_action_probs(states) + return probs.gather(1, actions).squeeze(-1) + + def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + probs = self._get_states_actions_probs_impl(states, actions) + return torch.log(probs) + + +class ContinuousPolicyNet(PolicyNet, metaclass=ABCMeta): + """Policy network for continuous action spaces. + + Args: + state_dim (int): Dimension of states. + action_dim (int): Dimension of actions. + """ + + def __init__(self, state_dim: int, action_dim: int) -> None: + super(ContinuousPolicyNet, self).__init__(state_dim=state_dim, action_dim=action_dim) diff --git a/maro/rl/model/q_net.py b/maro/rl/model/q_net.py new file mode 100644 index 000000000..777cf0b9d --- /dev/null +++ b/maro/rl/model/q_net.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from abc import ABCMeta, abstractmethod + +import torch + +from maro.rl.utils import match_shape, SHAPE_CHECK_FLAG + +from .abs_net import AbsNet + + +class QNet(AbsNet, metaclass=ABCMeta): + """Abstract Q-value network. + + Args: + state_dim (int): Dimension of states. + action_dim (int): Dimension of actions. + """ + + def __init__(self, state_dim: int, action_dim: int) -> None: + super(QNet, self).__init__() + self._state_dim = state_dim + self._action_dim = action_dim + + @property + def state_dim(self) -> int: + return self._state_dim + + @property + def action_dim(self) -> int: + return self._action_dim + + def _shape_check(self, states: torch.Tensor, actions: torch.Tensor = None) -> bool: + """Check whether the states and actions have valid shapes. + + Args: + states (torch.Tensor): State tensor. + actions (torch.Tensor, default=None): Action tensor. If it is None, it means we only check state tensor's + shape. + + Returns: + valid_flag (bool): whether the states and actions have valid shapes. + """ + if not SHAPE_CHECK_FLAG: + return True + else: + if states.shape[0] == 0 or not match_shape(states, (None, self.state_dim)): + return False + if actions is not None: + if not match_shape(actions, (states.shape[0], self.action_dim)): + return False + return True + + def q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """Get Q-values according to states and actions. + + Args: + states (torch.Tensor): States. + actions (torch.Tensor): Actions. + + Returns: + q (torch.Tensor): Q-values with shape [batch_size]. + """ + assert self._shape_check(states=states, actions=actions), \ + f"States or action shape check failed. Expecting: " \ + f"states = {('BATCH_SIZE', self.state_dim)}, action = {('BATCH_SIZE', self.action_dim)}. " \ + f"Actual: states = {states.shape}, action = {actions.shape}." + q = self._get_q_values(states, actions) + assert match_shape(q, (states.shape[0],)), \ + f"Q-value shape check failed. Expecting: {(states.shape[0],)}, actual: {q.shape}." # [B] + return q + + @abstractmethod + def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """Implementation of `q_values`. + """ + raise NotImplementedError + + +class DiscreteQNet(QNet, metaclass=ABCMeta): + """Q-value network for discrete action spaces. + + Args: + state_dim (int): Dimension of states. + action_num (int): Number of actions. + """ + + def __init__(self, state_dim: int, action_num: int) -> None: + super(DiscreteQNet, self).__init__(state_dim=state_dim, action_dim=1) + self._action_num = action_num + + @property + def action_num(self) -> int: + return self._action_num + + def q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + """Get Q-values for all actions according to states. + + Args: + states (torch.Tensor): States. + + Returns: + q (torch.Tensor): Q-values for all actions. The returned value has the shape [batch_size, action_num]. + """ + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + q = self._get_q_values_for_all_actions(states) + assert match_shape(q, (states.shape[0], self.action_num)), \ + f"Q-value matrix shape check failed. Expecting: {(states.shape[0], self.action_num)}, " \ + f"actual: {q.shape}." 
# [B, action_num] + return q + + def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + q = self.q_values_for_all_actions(states) # [B, action_num] + return q.gather(1, actions.long()).reshape(-1) # [B, action_num] + [B, 1] => [B] + + @abstractmethod + def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor: + """Implementation of `q_values_for_all_actions`. + """ + raise NotImplementedError + + +class ContinuousQNet(QNet, metaclass=ABCMeta): + """Q-value network for continuous action spaces. + + Args: + state_dim (int): Dimension of states. + action_dim (int): Dimension of actions. + """ + + def __init__(self, state_dim: int, action_dim: int) -> None: + super(ContinuousQNet, self).__init__(state_dim=state_dim, action_dim=action_dim) diff --git a/maro/rl/model/v_net.py b/maro/rl/model/v_net.py new file mode 100644 index 000000000..6451aa514 --- /dev/null +++ b/maro/rl/model/v_net.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABCMeta, abstractmethod + +import torch + +from maro.rl.utils import match_shape, SHAPE_CHECK_FLAG + +from .abs_net import AbsNet + + +class VNet(AbsNet, metaclass=ABCMeta): + """V-value network. + + Args: + state_dim (int): Dimension of states. + """ + + def __init__(self, state_dim: int) -> None: + super(VNet, self).__init__() + self._state_dim = state_dim + + @property + def state_dim(self) -> int: + return self._state_dim + + def _shape_check(self, states: torch.Tensor) -> bool: + """Check whether the states have valid shapes. + + Args: + states (torch.Tensor): State tensor. + + Returns: + valid_flag (bool): whether the states and actions have valid shapes. + """ + if not SHAPE_CHECK_FLAG: + return True + else: + return states.shape[0] > 0 and match_shape(states, (None, self.state_dim)) + + def v_values(self, states: torch.Tensor) -> torch.Tensor: + """Get V-values according to states. + + Args: + states (torch.Tensor): States. + + Returns: + v (torch.Tensor): V-values with shape [batch_size]. + """ + assert self._shape_check(states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + v = self._get_v_values(states) + assert match_shape(v, (states.shape[0],)), \ + f"V-value shape check failed. Expecting: {(states.shape[0],)}, actual: {v.shape}." # [B] + return v + + @abstractmethod + def _get_v_values(self, states: torch.Tensor) -> torch.Tensor: + """Implementation of `v_values`. + """ + raise NotImplementedError diff --git a/maro/rl/policy/__init__.py b/maro/rl/policy/__init__.py new file mode 100644 index 000000000..485c6f2bb --- /dev/null +++ b/maro/rl/policy/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .abs_policy import AbsPolicy, DummyPolicy, RLPolicy, RuleBasedPolicy +from .continuous_rl_policy import ContinuousRLPolicy +from .discrete_rl_policy import DiscretePolicyGradient, DiscreteRLPolicy, ValueBasedPolicy + +__all__ = [ + "AbsPolicy", "DummyPolicy", "RLPolicy", "RuleBasedPolicy", + "ContinuousRLPolicy", + "DiscretePolicyGradient", "DiscreteRLPolicy", "ValueBasedPolicy", +] diff --git a/maro/rl/policy/abs_policy.py b/maro/rl/policy/abs_policy.py new file mode 100644 index 000000000..86a980275 --- /dev/null +++ b/maro/rl/policy/abs_policy.py @@ -0,0 +1,386 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
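Similarly, a minimal sketch of a concrete `DiscreteQNet` (illustrative only, not part of this diff; the class name and layer sizes are assumptions, and other `AbsNet` hooks may also need overriding):

import torch

from maro.rl.model import DiscreteQNet


class MyQNet(DiscreteQNet):  # hypothetical example class
    def __init__(self, state_dim: int = 4, action_num: int = 2) -> None:
        super().__init__(state_dim=state_dim, action_num=action_num)
        self._fc = torch.nn.Sequential(
            torch.nn.Linear(state_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, action_num),
        )

    def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
        # Must return a [batch_size, action_num] Q-value matrix; q_values() for specific
        # state-action pairs and the shape assertions come from the base class.
        return self._fc(states)

A concrete `VNet` subclass follows the same pattern, implementing `_get_v_values` to return a [batch_size] tensor.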
+ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch + +from maro.rl.utils import match_shape, ndarray_to_tensor, SHAPE_CHECK_FLAG + + +class AbsPolicy(object, metaclass=ABCMeta): + """Abstract policy class. A policy takes states as inputs and generates actions as outputs. A policy cannot + update itself. It has to be updated by external trainers through public interfaces. + + Args: + name (str): Name of this policy. + trainable (bool): Whether this policy is trainable. + """ + + def __init__(self, name: str, trainable: bool) -> None: + super(AbsPolicy, self).__init__() + self._name = name + self._trainable = trainable + + @abstractmethod + def get_actions(self, states: object) -> object: + """Get actions according to states. + + Args: + states (object): States. + + Returns: + actions (object): Actions. + """ + raise NotImplementedError + + @property + def name(self) -> str: + return self._name + + @property + def trainable(self) -> bool: + return self._trainable + + def set_name(self, name: str) -> None: + self._name = name + + @abstractmethod + def explore(self) -> None: + """Set the policy to exploring mode. + """ + raise NotImplementedError + + @abstractmethod + def exploit(self) -> None: + """Set the policy to exploiting mode. + """ + raise NotImplementedError + + @abstractmethod + def eval(self) -> None: + """Switch the policy to evaluation mode. + """ + raise NotImplementedError + + @abstractmethod + def train(self) -> None: + """Switch the policy to training mode. + """ + raise NotImplementedError + + def to_device(self, device: torch.device) -> None: + pass + + +class DummyPolicy(AbsPolicy): + """Dummy policy that takes no actions. + """ + + def __init__(self) -> None: + super(DummyPolicy, self).__init__(name='DUMMY_POLICY', trainable=False) + + def get_actions(self, states: object) -> None: + return None + + def explore(self) -> None: + pass + + def exploit(self) -> None: + pass + + def eval(self) -> None: + pass + + def train(self) -> None: + pass + + +class RuleBasedPolicy(AbsPolicy, metaclass=ABCMeta): + """Rule-based policy. The user should define the rule of this policy, and a rule-based policy is not trainable. + """ + + def __init__(self, name: str) -> None: + super(RuleBasedPolicy, self).__init__(name=name, trainable=False) + + def get_actions(self, states: List[object]) -> List[object]: + return self._rule(states) + + @abstractmethod + def _rule(self, states: List[object]) -> List[object]: + raise NotImplementedError + + def explore(self) -> None: + pass + + def exploit(self) -> None: + pass + + def eval(self) -> None: + pass + + def train(self) -> None: + pass + + +class RLPolicy(AbsPolicy, metaclass=ABCMeta): + """Reinforcement learning policy. + + Args: + name (str): Name of the policy. + state_dim (int): Dimension of states. + action_dim (int): Dimension of actions. + trainable (bool, default=True): Whether this policy is trainable. 
+ """ + + def __init__( + self, + name: str, + state_dim: int, + action_dim: int, + is_discrete_action: bool, + trainable: bool = True, + ) -> None: + super(RLPolicy, self).__init__(name=name, trainable=trainable) + self._state_dim = state_dim + self._action_dim = action_dim + self._is_exploring = False + + self._device: Optional[torch.device] = None + + self.is_discrete_action = is_discrete_action + + @property + def state_dim(self) -> int: + return self._state_dim + + @property + def action_dim(self) -> int: + return self._action_dim + + @property + def is_exploring(self) -> bool: + """Whether this policy is under exploring mode. + """ + return self._is_exploring + + def explore(self) -> None: + """Set the policy to exploring mode. + """ + self._is_exploring = True + + def exploit(self) -> None: + """Set the policy to exploiting mode. + """ + self._is_exploring = False + + @abstractmethod + def train_step(self, loss: torch.Tensor) -> None: + """Run a training step to update the policy according to the given loss. + + Args: + loss (torch.Tensor): Loss used to update the policy. + """ + raise NotImplementedError + + @abstractmethod + def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: + """Get the gradients with respect to all parameters of the internal nets according to the given loss. + + Args: + loss (torch.tensor): Loss used to update the model. + + Returns: + grad (Dict[str, torch.Tensor]): A dict that contains gradients of the internal nets for all parameters. + """ + raise NotImplementedError + + @abstractmethod + def apply_gradients(self, grad: dict) -> None: + """Apply gradients to the net to update all parameters. + + Args: + grad (Dict[str, torch.Tensor]): A dict that contains gradients for all parameters. + """ + raise NotImplementedError + + def get_actions(self, states: np.ndarray) -> np.ndarray: + actions = self.get_actions_tensor(ndarray_to_tensor(states, device=self._device)) + return actions.detach().cpu().numpy() + + def get_actions_tensor(self, states: torch.Tensor) -> torch.Tensor: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + actions = self._get_actions_impl(states) + + assert self._shape_check(states=states, actions=actions), \ + f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}." + + return actions + + def get_actions_with_probs(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + actions, probs = self._get_actions_with_probs_impl(states) + + assert self._shape_check(states=states, actions=actions), \ + f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}." + assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0] + + return actions, probs + + def get_actions_with_logps(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + actions, logps = self._get_actions_with_logps_impl(states) + + assert self._shape_check(states=states, actions=actions), \ + f"Actions shape check failed. Expecting: {(states.shape[0], self.action_dim)}, actual: {actions.shape}." 
+ assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0] + + return actions, logps + + def get_states_actions_probs(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + probs = self._get_states_actions_probs_impl(states, actions) + + assert len(probs.shape) == 1 and probs.shape[0] == states.shape[0] + + return probs + + def get_states_actions_logps(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + + logps = self._get_states_actions_logps_impl(states, actions) + + assert len(logps.shape) == 1 and logps.shape[0] == states.shape[0] + + return logps + + @abstractmethod + def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def freeze(self) -> None: + """(Partially) freeze the current model. The users should write their own strategy to determine which + parameters to freeze. + """ + raise NotImplementedError + + @abstractmethod + def unfreeze(self) -> None: + """(Partially) unfreeze the current model. The users should write their own strategy to determine which + parameters to freeze. + """ + raise NotImplementedError + + @abstractmethod + def get_state(self) -> object: + """Get the state of the policy. + """ + raise NotImplementedError + + @abstractmethod + def set_state(self, policy_state: dict) -> None: + """Set the state of the policy. + """ + raise NotImplementedError + + @abstractmethod + def soft_update(self, other_policy: RLPolicy, tau: float) -> None: + """Soft update the policy's parameters according to another policy. + + Args: + other_policy (AbsNet): The source policy. Must has same type with the current policy. + tau (float): Soft update coefficient. + """ + raise NotImplementedError + + def _shape_check( + self, + states: torch.Tensor, + actions: torch.Tensor = None, + ) -> bool: + """Check whether the states and actions have valid shapes. + + Args: + states (torch.Tensor): State tensor. + actions (torch.Tensor, default=None): Action tensor. If it is None, it means we only check state tensor's + shape. + + Returns: + valid_flag (bool): whether the states and actions have valid shapes. + """ + if not SHAPE_CHECK_FLAG: + return True + else: + if states.shape[0] == 0: + return False + if not match_shape(states, (None, self.state_dim)): + return False + + if actions is not None: + if not match_shape(actions, (states.shape[0], self.action_dim)): + return False + return True + + @abstractmethod + def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool: + """Check whether the generated action tensor is valid, i.e., has matching shape with states tensor. 
+ + Args: + states (torch.Tensor): State tensor. + actions (torch.Tensor): Action tensor. + + Returns: + valid_flag (bool): whether the action tensor is valid. + """ + raise NotImplementedError + + def to_device(self, device: torch.device) -> None: + """Assign the current policy to a specific device. + + Args: + device (torch.device): The target device. + """ + if self._device is None: + self._device = device + self._to_device_impl(device) + elif self._device != device: + raise ValueError( + f"Policy {self.name} has already been assigned to device {self._device} " + f"and cannot be re-assigned to device {device}" + ) + + @abstractmethod + def _to_device_impl(self, device: torch.device) -> None: + """Implementation of `to_device`. + """ + pass diff --git a/maro/rl/policy/continuous_rl_policy.py b/maro/rl/policy/continuous_rl_policy.py new file mode 100644 index 000000000..22e14b610 --- /dev/null +++ b/maro/rl/policy/continuous_rl_policy.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from maro.rl.model import ContinuousPolicyNet + +from .abs_policy import RLPolicy + + +def _parse_action_range( + action_dim: int, + action_range: Tuple[Union[float, List[float]], Union[float, List[float]]], +) -> Tuple[Optional[List[float]], Optional[List[float]]]: + lower, upper = action_range + + if isinstance(lower, float): + lower = [lower] * action_dim + if isinstance(upper, float): + upper = [upper] * action_dim + + if not (action_dim == len(lower) == len(upper)): + return None, None + + for lval, uval in zip(lower, upper): + if lval >= uval: + return None, None + + return lower, upper + + +class ContinuousRLPolicy(RLPolicy): + """RL policy for continuous action spaces. + + Args: + name (str): Name of the policy. + action_range (Tuple[Union[float, List[float]], Union[float, List[float]]]): Value range of actions. + Both the lower bound and the upper bound could be float or array. If it is an array, it should contain + the bound for every dimension. If it is a float, it will be broadcast to all dimensions. + policy_net (ContinuousPolicyNet): The core net of this policy. + trainable (bool, default=True): Whether this policy is trainable. 
+ """ + + def __init__( + self, + name: str, + action_range: Tuple[Union[float, List[float]], Union[float, List[float]]], + policy_net: ContinuousPolicyNet, + trainable: bool = True, + ) -> None: + assert isinstance(policy_net, ContinuousPolicyNet) + + super(ContinuousRLPolicy, self).__init__( + name=name, state_dim=policy_net.state_dim, action_dim=policy_net.action_dim, + trainable=trainable, is_discrete_action=False, + ) + + self._lbounds, self._ubounds = _parse_action_range(self.action_dim, action_range) + assert self._lbounds is not None and self._ubounds is not None + + self._policy_net = policy_net + + @property + def action_bounds(self) -> Tuple[List[float], List[float]]: + return self._lbounds, self._ubounds + + @property + def policy_net(self) -> ContinuousPolicyNet: + return self._policy_net + + def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool: + return all([ + (np.array(self._lbounds) <= actions.detach().cpu().numpy()).all(), + (actions.detach().cpu().numpy() < np.array(self._ubounds)).all() + ]) + + def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor: + return self._policy_net.get_actions(states, self._is_exploring) + + def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self._policy_net.get_actions_with_probs(states, self._is_exploring) + + def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self._policy_net.get_actions_with_logps(states, self._is_exploring) + + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + return self._policy_net.get_states_actions_probs(states, actions) + + def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + return self._policy_net.get_states_actions_logps(states, actions) + + def train_step(self, loss: torch.Tensor) -> None: + self._policy_net.step(loss) + + def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: + return self._policy_net.get_gradients(loss) + + def apply_gradients(self, grad: dict) -> None: + self._policy_net.apply_gradients(grad) + + def freeze(self) -> None: + self._policy_net.freeze() + + def unfreeze(self) -> None: + self._policy_net.unfreeze() + + def eval(self) -> None: + self._policy_net.eval() + + def train(self) -> None: + self._policy_net.train() + + def get_state(self) -> object: + return self._policy_net.get_state() + + def set_state(self, policy_state: dict) -> None: + self._policy_net.set_state(policy_state) + + def soft_update(self, other_policy: RLPolicy, tau: float) -> None: + assert isinstance(other_policy, ContinuousRLPolicy) + self._policy_net.soft_update(other_policy.policy_net, tau) + + def _to_device_impl(self, device: torch.device) -> None: + self._policy_net.to(device) diff --git a/maro/rl/policy/discrete_rl_policy.py b/maro/rl/policy/discrete_rl_policy.py new file mode 100644 index 000000000..42babab50 --- /dev/null +++ b/maro/rl/policy/discrete_rl_policy.py @@ -0,0 +1,336 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from abc import ABCMeta +from typing import Callable, Dict, List, Tuple + +import numpy as np +import torch + +from maro.rl.exploration import epsilon_greedy +from maro.rl.model import DiscretePolicyNet, DiscreteQNet +from maro.rl.utils import match_shape, ndarray_to_tensor +from maro.utils import clone + +from .abs_policy import RLPolicy + + +class DiscreteRLPolicy(RLPolicy, metaclass=ABCMeta): + """RL policy for discrete action spaces. + + Args: + name (str): Name of the policy. + state_dim (int): Dimension of states. + action_num (int): Number of actions. + trainable (bool, default=True): Whether this policy is trainable. + """ + + def __init__( + self, + name: str, + state_dim: int, + action_num: int, + trainable: bool = True, + ) -> None: + assert action_num >= 1 + + super(DiscreteRLPolicy, self).__init__( + name=name, state_dim=state_dim, action_dim=1, trainable=trainable, is_discrete_action=True, + ) + + self._action_num = action_num + + @property + def action_num(self) -> int: + return self._action_num + + def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool: + return all([0 <= action < self.action_num for action in actions.cpu().numpy().flatten()]) + + +class ValueBasedPolicy(DiscreteRLPolicy): + """Valued-based policy. + + Args: + name (str): Name of the policy. + q_net (DiscreteQNet): Q-net used in this value-based policy. + trainable (bool, default=True): Whether this policy is trainable. + exploration_strategy (Tuple[Callable, dict], default=(epsilon_greedy, {"epsilon": 0.1})): Exploration strategy. + exploration_scheduling_options (List[tuple], default=None): List of exploration scheduler options. + warmup (int, default=50000): Minimum number of experiences to warm up this policy. + """ + + def __init__( + self, + name: str, + q_net: DiscreteQNet, + trainable: bool = True, + exploration_strategy: Tuple[Callable, dict] = (epsilon_greedy, {"epsilon": 0.1}), + exploration_scheduling_options: List[tuple] = None, + warmup: int = 50000, + ) -> None: + assert isinstance(q_net, DiscreteQNet) + + super(ValueBasedPolicy, self).__init__( + name=name, state_dim=q_net.state_dim, action_num=q_net.action_num, trainable=trainable, + ) + self._q_net = q_net + + self._exploration_func = exploration_strategy[0] + self._exploration_params = clone(exploration_strategy[1]) # deep copy is needed to avoid unwanted sharing + self._exploration_schedulers = [ + opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options + ] + + self._call_cnt = 0 + self._warmup = warmup + + self._softmax = torch.nn.Softmax(dim=1) + + @property + def q_net(self) -> DiscreteQNet: + return self._q_net + + def q_values_for_all_actions(self, states: np.ndarray) -> np.ndarray: + """Generate a matrix containing the Q-values for all actions for the given states. + + Args: + states (np.ndarray): States. + + Returns: + q_values (np.ndarray): Q-matrix. + """ + return self.q_values_for_all_actions_tensor(ndarray_to_tensor(states, device=self._device)).cpu().numpy() + + def q_values_for_all_actions_tensor(self, states: torch.Tensor) -> torch.Tensor: + """Generate a matrix containing the Q-values for all actions for the given states. + + Args: + states (torch.Tensor): States. + + Returns: + q_values (torch.Tensor): Q-matrix. 
+ """ + assert self._shape_check(states=states) + q_values = self._q_net.q_values_for_all_actions(states) + assert match_shape(q_values, (states.shape[0], self.action_num)) # [B, action_num] + return q_values + + def q_values(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray: + """Generate the Q values for given state-action pairs. + + Args: + states (np.ndarray): States. + actions (np.ndarray): Actions. Should has same length with states. + + Returns: + q_values (np.ndarray): Q-values. + """ + return self.q_values_tensor( + ndarray_to_tensor(states, device=self._device), + ndarray_to_tensor(actions, device=self._device) + ).cpu().numpy() + + def q_values_tensor(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + """Generate the Q values for given state-action pairs. + + Args: + states (torch.Tensor): States. + actions (torch.Tensor): Actions. Should has same length with states. + + Returns: + q_values (torch.Tensor): Q-values. + """ + assert self._shape_check(states=states, actions=actions) # actions: [B, 1] + q_values = self._q_net.q_values(states, actions) + assert match_shape(q_values, (states.shape[0],)) # [B] + return q_values + + def explore(self) -> None: + pass # Overwrite the base method and turn off explore mode. + + def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor: + actions, _ = self._get_actions_with_probs_impl(states) + return actions + + def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + self._call_cnt += 1 + if self._call_cnt <= self._warmup: + actions = ndarray_to_tensor( + np.random.randint(self.action_num, size=(states.shape[0], 1)), + device=self._device, + ) + probs = torch.ones(states.shape[0]).float() * (1.0 / self.action_num) + return actions, probs + + q_matrix = self.q_values_for_all_actions_tensor(states) # [B, action_num] + q_matrix_softmax = self._softmax(q_matrix) + _, actions = q_matrix.max(dim=1) # [B], [B] + + if self._is_exploring: + actions = self._exploration_func(states, actions.cpu().numpy(), self.action_num, **self._exploration_params) + actions = ndarray_to_tensor(actions, device=self._device) + + actions = actions.unsqueeze(1) + return actions, q_matrix_softmax.gather(1, actions).squeeze(-1) # [B, 1] + + def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + actions, probs = self._get_actions_with_probs_impl(states) + return actions, torch.log(probs) + + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + q_matrix = self.q_values_for_all_actions_tensor(states) + q_matrix_softmax = self._softmax(q_matrix) + return q_matrix_softmax.gather(1, actions).squeeze(-1) # [B] + + def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + probs = self._get_states_actions_probs_impl(states, actions) + return torch.log(probs) + + def train_step(self, loss: torch.Tensor) -> None: + return self._q_net.step(loss) + + def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: + return self._q_net.get_gradients(loss) + + def apply_gradients(self, grad: dict) -> None: + self._q_net.apply_gradients(grad) + + def freeze(self) -> None: + self._q_net.freeze() + + def unfreeze(self) -> None: + self._q_net.unfreeze() + + def eval(self) -> None: + self._q_net.eval() + + def train(self) -> None: + self._q_net.train() + + def get_state(self) -> object: + return self._q_net.get_state() + + def set_state(self, policy_state: 
dict) -> None: + self._q_net.set_state(policy_state) + + def soft_update(self, other_policy: RLPolicy, tau: float) -> None: + assert isinstance(other_policy, ValueBasedPolicy) + self._q_net.soft_update(other_policy.q_net, tau) + + def _to_device_impl(self, device: torch.device) -> None: + self._q_net.to(device) + + +class DiscretePolicyGradient(DiscreteRLPolicy): + """Policy gradient for discrete action spaces. + + Args: + name (str): Name of the policy. + policy_net (DiscretePolicyNet): The core net of this policy. + trainable (bool, default=True): Whether this policy is trainable. + """ + + def __init__( + self, + name: str, + policy_net: DiscretePolicyNet, + trainable: bool = True, + ) -> None: + assert isinstance(policy_net, DiscretePolicyNet) + + super(DiscretePolicyGradient, self).__init__( + name=name, state_dim=policy_net.state_dim, action_num=policy_net.action_num, + trainable=trainable, + ) + + self._policy_net = policy_net + + @property + def policy_net(self) -> DiscretePolicyNet: + return self._policy_net + + def _get_actions_impl(self, states: torch.Tensor) -> torch.Tensor: + return self._policy_net.get_actions(states, self._is_exploring) + + def _get_actions_with_probs_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self._policy_net.get_actions_with_probs(states, self._is_exploring) + + def _get_actions_with_logps_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self._policy_net.get_actions_with_logps(states, self._is_exploring) + + def _get_states_actions_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + return self._policy_net.get_states_actions_probs(states, actions) + + def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + return self._policy_net.get_states_actions_logps(states, actions) + + def train_step(self, loss: torch.Tensor) -> None: + self._policy_net.step(loss) + + def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]: + return self._policy_net.get_gradients(loss) + + def apply_gradients(self, grad: dict) -> None: + self._policy_net.apply_gradients(grad) + + def freeze(self) -> None: + self._policy_net.freeze() + + def unfreeze(self) -> None: + self._policy_net.unfreeze() + + def eval(self) -> None: + self._policy_net.eval() + + def train(self) -> None: + self._policy_net.train() + + def get_state(self) -> dict: + return self._policy_net.get_state() + + def set_state(self, policy_state: dict) -> None: + self._policy_net.set_state(policy_state) + + def soft_update(self, other_policy: RLPolicy, tau: float) -> None: + assert isinstance(other_policy, DiscretePolicyGradient) + self._policy_net.soft_update(other_policy.policy_net, tau) + + def get_action_probs(self, states: torch.Tensor) -> torch.Tensor: + """Get the probabilities for all actions according to states. + + Args: + states (torch.Tensor): States. + + Returns: + action_probs (torch.Tensor): Action probabilities with shape [batch_size, action_num]. + """ + assert self._shape_check(states=states), \ + f"States shape check failed. Expecting: {('BATCH_SIZE', self.state_dim)}, actual: {states.shape}." + action_probs = self._policy_net.get_action_probs(states) + assert match_shape(action_probs, (states.shape[0], self.action_num)), \ + f"Action probabilities shape check failed. Expecting: {(states.shape[0], self.action_num)}, " \ + f"actual: {action_probs.shape}." 
+ return action_probs + + def get_action_logps(self, states: torch.Tensor) -> torch.Tensor: + """Get the log-probabilities for all actions according to states. + + Args: + states (torch.Tensor): States. + + Returns: + action_logps (torch.Tensor): Action probabilities with shape [batch_size, action_num]. + """ + return torch.log(self.get_action_probs(states)) + + def _get_state_action_probs_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + action_probs = self.get_action_probs(states) + return action_probs.gather(1, actions).squeeze(-1) # [B] + + def _get_state_action_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: + action_logps = self.get_action_logps(states) + return action_logps.gather(1, actions).squeeze(-1) # [B] + + def _to_device_impl(self, device: torch.device) -> None: + self._policy_net.to(device) diff --git a/maro/rl/rl_component/__init__.py b/maro/rl/rl_component/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/maro/rl/rl_component/rl_component_bundle.py b/maro/rl/rl_component/rl_component_bundle.py new file mode 100644 index 000000000..729f1ce85 --- /dev/null +++ b/maro/rl/rl_component/rl_component_bundle.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import abstractmethod +from functools import partial +from typing import Any, Callable, Dict, List, Optional + +from maro.rl.policy import AbsPolicy +from maro.rl.rollout import AbsEnvSampler +from maro.rl.training import AbsTrainer +from maro.simulator import Env + + +class RLComponentBundle(object): + """Bundle of all necessary components to run a RL job in MARO. + + Users should create their own subclass of `RLComponentBundle` and implement following methods: + - get_env_config() + - get_test_env_config() + - get_env_sampler() + - get_agent2policy() + - get_policy_creator() + - get_trainer_creator() + + Following methods could be overwritten when necessary: + - get_device_mapping() + + Please refer to the doc string of each method for detailed explanations. + """ + def __init__(self) -> None: + super(RLComponentBundle, self).__init__() + + self.trainer_creator: Optional[Dict[str, Callable[[], AbsTrainer]]] = None + + self.agent2policy: Optional[Dict[Any, str]] = None + self.trainable_agent2policy: Optional[Dict[Any, str]] = None + self.policy_creator: Optional[Dict[str, Callable[[], AbsPolicy]]] = None + self.policy_names: Optional[List[str]] = None + self.trainable_policy_creator: Optional[Dict[str, Callable[[], AbsPolicy]]] = None + self.trainable_policy_names: Optional[List[str]] = None + + self.device_mapping: Optional[Dict[str, str]] = None + self.policy_trainer_mapping: Optional[Dict[str, str]] = None + + self._policy_cache: Optional[Dict[str, AbsPolicy]] = None + + # Will be created when `env_sampler()` is first called + self._env_sampler: Optional[AbsEnvSampler] = None + + self._complete_resources() + + ######################################################################################## + # Users MUST implement the following methods # + ######################################################################################## + @abstractmethod + def get_env_config(self) -> dict: + """Return the environment configuration to build the MARO Env for training. + + Returns: + Environment configuration. + """ + raise NotImplementedError + + @abstractmethod + def get_test_env_config(self) -> Optional[dict]: + """Return the environment configuration to build the MARO Env for testing. 
If returns `None`, the training + environment will be reused as testing environment. + + Returns: + Environment configuration or `None`. + """ + raise NotImplementedError + + @abstractmethod + def get_env_sampler(self) -> AbsEnvSampler: + """Return the environment sampler of the scenario. + + Returns: + The environment sampler of the scenario. + """ + raise NotImplementedError + + @abstractmethod + def get_agent2policy(self) -> Dict[Any, str]: + """Return agent name to policy name mapping of the RL job. This mapping identifies which policy should + the agents use. For example: {agent1: policy1, agent2: policy1, agent3: policy2}. + + Returns: + Agent name to policy name mapping. + """ + raise NotImplementedError + + @abstractmethod + def get_policy_creator(self) -> Dict[str, Callable[[], AbsPolicy]]: + """Return policy creator. Policy creator is a dictionary that contains a group of functions that generate + policy instances. The key of this dictionary is the policy name, and the value is the function that generate + the corresponding policy instance. Note that the creation function should not take any parameters. + """ + raise NotImplementedError + + @abstractmethod + def get_trainer_creator(self) -> Dict[str, Callable[[], AbsTrainer]]: + """Return trainer creator. Trainer creator is similar to policy creator, but is used to creator trainers. + """ + raise NotImplementedError + + ######################################################################################## + # Users could overwrite the following methods # + ######################################################################################## + def get_device_mapping(self) -> Dict[str, str]: + """Return the device mapping that identifying which device to put each policy. + + If user does not overwrite this method, then all policies will be put on CPU by default. + """ + return {policy_name: "cpu" for policy_name in self.get_policy_creator()} + + def get_policy_trainer_mapping(self) -> Dict[str, str]: + """Return the policy-trainer mapping which identifying which trainer to train each policy. + + If user does not overwrite this method, then a policy's trainer's name is the first segment of the policy's + name, seperated by dot. For example, "ppo_1.policy" is trained by "ppo_1". + + Only policies that provided in policy-trainer mapping are considered as trainable polices. Policies that + not provided in policy-trainer mapping will not be trained since we do not assign a trainer to it. + """ + return { + policy_name: policy_name.split(".")[0] for policy_name in self.policy_creator + } + + ######################################################################################## + # Methods invisible to users # + ######################################################################################## + @property + def env_sampler(self) -> AbsEnvSampler: + if self._env_sampler is None: + self._env_sampler = self.get_env_sampler() + self._env_sampler.build(self) + return self._env_sampler + + def _complete_resources(self) -> None: + """Generate all attributes by calling user-defined logics. Do necessary checking and transformations. 
+ """ + env_config = self.get_env_config() + test_env_config = self.get_test_env_config() + self.env = Env(**env_config) + self.test_env = self.env if test_env_config is None else Env(**test_env_config) + + self.trainer_creator = self.get_trainer_creator() + self.device_mapping = self.get_device_mapping() + + self.policy_creator = self.get_policy_creator() + self.agent2policy = self.get_agent2policy() + + self.policy_trainer_mapping = self.get_policy_trainer_mapping() + + required_policies = set(self.agent2policy.values()) + self.policy_creator = {name: self.policy_creator[name] for name in required_policies} + self.policy_trainer_mapping = { + name: self.policy_trainer_mapping[name] + for name in required_policies + if name in self.policy_trainer_mapping + } + self.policy_names = list(required_policies) + assert len(required_policies) == len(self.policy_creator) # Should have same size after filter + + required_trainers = set(self.policy_trainer_mapping.values()) + self.trainer_creator = {name: self.trainer_creator[name] for name in required_trainers} + assert len(required_trainers) == len(self.trainer_creator) # Should have same size after filter + + self.trainable_policy_names = list(self.policy_trainer_mapping.keys()) + self.trainable_policy_creator = { + policy_name: self.policy_creator[policy_name] + for policy_name in self.trainable_policy_names + } + self.trainable_agent2policy = { + agent_name: policy_name + for agent_name, policy_name in self.agent2policy.items() + if policy_name in self.trainable_policy_names + } + + def pre_create_policy_instances(self) -> None: + """Pre-create policy instances, and return the pre-created policy instances when the external callers + want to create new policies. This will ensure that each policy will have at most one reusable duplicate. + Under specific scenarios (for example, simple training & rollout), this will reduce unnecessary overheads. + """ + old_policy_creator = self.policy_creator + self._policy_cache: Dict[str, AbsPolicy] = {} + for policy_name in self.policy_names: + self._policy_cache[policy_name] = old_policy_creator[policy_name]() + + def _get_policy_instance(policy_name: str) -> AbsPolicy: + return self._policy_cache[policy_name] + + self.policy_creator = { + policy_name: partial(_get_policy_instance, policy_name) + for policy_name in self.policy_names + } + + self.trainable_policy_creator = { + policy_name: self.policy_creator[policy_name] + for policy_name in self.trainable_policy_names + } diff --git a/maro/rl/rollout/__init__.py b/maro/rl/rollout/__init__.py new file mode 100644 index 000000000..41a1995b7 --- /dev/null +++ b/maro/rl/rollout/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .batch_env_sampler import BatchEnvSampler +from .env_sampler import AbsAgentWrapper, AbsEnvSampler, CacheElement, ExpElement, SimpleAgentWrapper +from .worker import RolloutWorker + +__all__ = [ + "BatchEnvSampler", + "AbsAgentWrapper", "AbsEnvSampler", "CacheElement", "ExpElement", "SimpleAgentWrapper", + "RolloutWorker", +] diff --git a/maro/rl/rollout/batch_env_sampler.py b/maro/rl/rollout/batch_env_sampler.py new file mode 100644 index 000000000..885687d1e --- /dev/null +++ b/maro/rl/rollout/batch_env_sampler.py @@ -0,0 +1,210 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+import os +import time +from itertools import chain +from typing import Dict, List, Optional, Tuple + +import torch +import zmq +from zmq import Context, Poller + +from maro.rl.utils.common import bytes_to_pyobj, get_own_ip_address, pyobj_to_bytes +from maro.rl.utils.objects import FILE_SUFFIX +from maro.utils import DummyLogger, LoggerV2 + +from .env_sampler import ExpElement + + +class ParallelTaskController(object): + """Controller that sends identical tasks to a set of remote workers and collect results from them. + + Args: + port (int, default=20000): Network port the controller uses to talk to the remote workers. + logger (LoggerV2, default=None): Optional logger for logging key events. + """ + + def __init__(self, port: int = 20000, logger: LoggerV2 = None) -> None: + self._ip = get_own_ip_address() + self._context = Context.instance() + + # parallel task sender + self._task_endpoint = self._context.socket(zmq.ROUTER) + self._task_endpoint.setsockopt(zmq.LINGER, 0) + self._task_endpoint.bind(f"tcp://{self._ip}:{port}") + + self._poller = Poller() + self._poller.register(self._task_endpoint, zmq.POLLIN) + + self._workers = set() + self._logger = logger + + def _wait_for_workers_ready(self, k: int) -> None: + while len(self._workers) < k: + self._workers.add(self._task_endpoint.recv_multipart()[0]) + + def _recv_result_for_target_index(self, index: int) -> object: + rep = bytes_to_pyobj(self._task_endpoint.recv_multipart()[-1]) + assert isinstance(rep, dict) + return rep["result"] if rep["index"] == index else None + + def collect(self, req: dict, parallelism: int, min_replies: int = None, grace_factor: int = None) -> List[dict]: + """Send a task request to a set of remote workers and collect the results. + + Args: + req (dict): Request containing task specifications and parameters. + parallelism (int): Number of workers to send the task to. + min_replies (int, default=None): The minimum number of results to collect in one round of remote + sampling. If None, it defaults to the value of ``parallelism``. + grace_factor (float, default=None): Factor that determines the additional wait time after receiving the + minimum required replies (as determined by ``min_replies``). For example, if the minimum required + replies are received in T seconds, it will allow an additional T * grace_factor seconds to collect + the remaining results. + + Returns: + A list of results. Each element in the list is a dict that contains results from a worker. 
+ """ + self._wait_for_workers_ready(parallelism) + if min_replies is None: + min_replies = parallelism + + start_time = time.time() + results = [] + for worker_id in list(self._workers)[:parallelism]: + self._task_endpoint.send_multipart([worker_id, pyobj_to_bytes(req)]) + self._logger.debug(f"Sent {parallelism} roll-out requests...") + + while len(results) < min_replies: + result = self._recv_result_for_target_index(req["index"]) + if result: + results.append(result) + + if grace_factor is not None: + countdown = int((time.time() - start_time) * grace_factor) * 1000 # milliseconds + self._logger.debug(f"allowing {countdown / 1000} seconds for remaining results") + while len(results) < parallelism and countdown > 0: + start = time.time() + event = dict(self._poller.poll(countdown)) + if self._task_endpoint in event: + result = self._recv_result_for_target_index(req["index"]) + if result: + results.append(result) + countdown -= time.time() - start + + self._logger.debug(f"Received {len(results)} results") + return results + + def exit(self) -> None: + """Signal the remote workers to exit and terminate the connections. + """ + for worker_id in self._workers: + self._task_endpoint.send_multipart([worker_id, b"EXIT"]) + self._task_endpoint.close() + self._context.term() + + +class BatchEnvSampler: + """Facility that samples from multiple copies of an environment in parallel. + + No environment is created here. Instead, it uses a ParallelTaskController to send roll-out requests to a set of + remote workers and collect results from them. + + Args: + sampling_parallelism (int): Parallelism for sampling from the environment. + port (int): Network port that the internal ``ParallelTaskController`` uses to talk to the remote workers. + min_env_samples (int, default=None): The minimum number of results to collect in one round of remote sampling. + If it is None, it defaults to the value of ``sampling_parallelism``. + grace_factor (float, default=None): Factor that determines the additional wait time after receiving the minimum + required env samples (as determined by ``min_env_samples``). For example, if the minimum required samples + are received in T seconds, it will allow an additional T * grace_factor seconds to collect the remaining + results. + eval_parallelism (int, default=None): Parallelism for policy evaluation on remote workers. + logger (LoggerV2, default=None): Optional logger for logging key events. + """ + + def __init__( + self, + sampling_parallelism: int, + port: int = 20000, + min_env_samples: int = None, + grace_factor: float = None, + eval_parallelism: int = None, + logger: LoggerV2 = None, + ) -> None: + super(BatchEnvSampler, self).__init__() + self._logger = logger if logger else DummyLogger() + self._controller = ParallelTaskController(port=port, logger=logger) + + self._sampling_parallelism = 1 if sampling_parallelism is None else sampling_parallelism + self._min_env_samples = min_env_samples if min_env_samples is not None else self._sampling_parallelism + self._grace_factor = grace_factor + self._eval_parallelism = 1 if eval_parallelism is None else eval_parallelism + + self._ep = 0 + self._end_of_episode = True + + def sample(self, policy_state: Optional[Dict[str, object]] = None, num_steps: Optional[int] = None) -> dict: + """Collect experiences from a set of remote roll-out workers. + + Args: + policy_state (Dict[str, object]): Policy state dict. 
If it is not None, then we need to update all + policies according to the latest policy states, then start the experience collection. + num_steps (Optional[int], default=None): Number of environment steps to collect experiences for. If + it is None, interactions with the (remote) environments will continue until the terminal state is + reached. + + Returns: + A dict that contains the collected experiences and additional information. + """ + # increment episode depending on whether the last episode has concluded + if self._end_of_episode: + self._ep += 1 + + self._logger.info(f"Collecting roll-out data for episode {self._ep}") + req = { + "type": "sample", + "policy_state": policy_state, + "num_steps": num_steps, + "index": self._ep, + } + results = self._controller.collect( + req, self._sampling_parallelism, + min_replies=self._min_env_samples, + grace_factor=self._grace_factor, + ) + self._end_of_episode = any(res["end_of_episode"] for res in results) + merged_experiences: List[List[ExpElement]] = list(chain(*[res["experiences"] for res in results])) + return { + "end_of_episode": self._end_of_episode, + "experiences": merged_experiences, + "info": [res["info"][0] for res in results], + } + + def eval(self, policy_state: Dict[str, object] = None) -> dict: + req = {"type": "eval", "policy_state": policy_state, "index": self._ep} # -1 signals test + results = self._controller.collect(req, self._eval_parallelism) + return { + "info": [res["info"][0] for res in results], + } + + def load_policy_state(self, path: str) -> List[str]: + file_list = os.listdir(path) + policy_state_dict = {} + loaded = [] + for file_name in file_list: + if "non_policy" in file_name or not file_name.endswith(f"_policy.{FILE_SUFFIX}"): # TODO: remove hardcode + continue + policy_name, policy_state = torch.load(os.path.join(path, file_name)) + policy_state_dict[policy_name] = policy_state + loaded.append(policy_name) + + req = { + "type": "set_policy_state", + "policy_state": policy_state_dict, + "index": self._ep, + } + self._controller.collect(req, self._sampling_parallelism) + return loaded + + def exit(self) -> None: + self._controller.exit() diff --git a/maro/rl/rollout/env_sampler.py b/maro/rl/rollout/env_sampler.py new file mode 100644 index 000000000..f396b2342 --- /dev/null +++ b/maro/rl/rollout/env_sampler.py @@ -0,0 +1,562 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import collections +import os +import typing +from abc import ABCMeta, abstractmethod +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch + +from maro.rl.policy import AbsPolicy, RLPolicy +from maro.rl.utils.objects import FILE_SUFFIX +from maro.simulator import Env + +if typing.TYPE_CHECKING: + from maro.rl.rl_component.rl_component_bundle import RLComponentBundle + + +class AbsAgentWrapper(object, metaclass=ABCMeta): + """Agent wrapper. Used to manager agents & policies during experience collection. + + Args: + policy_dict (Dict[str, AbsPolicy]): Dictionary that maps policy names to policy instances. + agent2policy (Dict[Any, str]): Agent name to policy name mapping. 
+ """ + + def __init__( + self, + policy_dict: Dict[str, AbsPolicy], # {policy_name: AbsPolicy} + agent2policy: Dict[Any, str], # {agent_name: policy_name} + ) -> None: + self._policy_dict = policy_dict + self._agent2policy = agent2policy + + def set_policy_state(self, policy_state_dict: Dict[str, dict]) -> None: + """Set policies' states. + + Args: + policy_state_dict (Dict[str, dict]): Double-deck dict with format: {policy_name: policy_state}. + """ + for policy_name, policy_state in policy_state_dict.items(): + policy = self._policy_dict[policy_name] + if isinstance(policy, RLPolicy): + policy.set_state(policy_state) + + def choose_actions( + self, state_by_agent: Dict[Any, Union[np.ndarray, List[object]]] + ) -> Dict[Any, Union[np.ndarray, List[object]]]: + """Choose action according to the given (observable) states of all agents. + + Args: + state_by_agent (Dict[Any, Union[np.ndarray, List[object]]]): Dictionary containing each agent's states. + If the policy is a `RLPolicy`, its state is a Numpy array. Otherwise, its state is a list of objects. + + Returns: + actions (Dict[Any, Union[np.ndarray, List[object]]]): Dict that contains the action for all agents. + If the policy is a `RLPolicy`, its action is a Numpy array. Otherwise, its action is a list of objects. + """ + self.switch_to_eval_mode() + with torch.no_grad(): + ret = self._choose_actions_impl(state_by_agent) + return ret + + @abstractmethod + def _choose_actions_impl( + self, state_by_agent: Dict[Any, Union[np.ndarray, List[object]]], + ) -> Dict[Any, Union[np.ndarray, List[object]]]: + """Implementation of `choose_actions`. + """ + raise NotImplementedError + + @abstractmethod + def explore(self) -> None: + """Switch all policies to exploration mode. + """ + raise NotImplementedError + + @abstractmethod + def exploit(self) -> None: + """Switch all policies to exploitation mode. + """ + raise NotImplementedError + + @abstractmethod + def switch_to_eval_mode(self) -> None: + """Switch the environment sampler to evaluation mode. 
+ """ + pass + + +class SimpleAgentWrapper(AbsAgentWrapper): + def __init__( + self, + policy_dict: Dict[str, RLPolicy], # {policy_name: RLPolicy} + agent2policy: Dict[Any, str], # {agent_name: policy_name} + ) -> None: + super(SimpleAgentWrapper, self).__init__(policy_dict=policy_dict, agent2policy=agent2policy) + + def _choose_actions_impl( + self, state_by_agent: Dict[Any, Union[np.ndarray, List[object]]], + ) -> Dict[Any, Union[np.ndarray, List[object]]]: + # Aggregate states by policy + states_by_policy = collections.defaultdict(list) # {str: list of np.ndarray} + agents_by_policy = collections.defaultdict(list) # {str: list of str} + for agent_name, state in state_by_agent.items(): + policy_name = self._agent2policy[agent_name] + states_by_policy[policy_name].append(state) + agents_by_policy[policy_name].append(agent_name) + + action_dict = {} + for policy_name in agents_by_policy: + policy = self._policy_dict[policy_name] + + if isinstance(policy, RLPolicy): + states = np.vstack(states_by_policy[policy_name]) # np.ndarray + else: + states = states_by_policy[policy_name] # List[object] + actions = policy.get_actions(states) # np.ndarray or List[object] + action_dict.update(zip(agents_by_policy[policy_name], actions)) + + return action_dict + + def explore(self) -> None: + for policy in self._policy_dict.values(): + policy.explore() + + def exploit(self) -> None: + for policy in self._policy_dict.values(): + policy.exploit() + + def switch_to_eval_mode(self) -> None: + for policy in self._policy_dict.values(): + policy.eval() + + +@dataclass +class ExpElement: + """Stores the complete information for a tick. + """ + tick: int + state: np.ndarray + agent_state_dict: Dict[Any, np.ndarray] + action_dict: Dict[Any, np.ndarray] + reward_dict: Dict[Any, float] + terminal_dict: Dict[Any, bool] + next_state: Optional[np.ndarray] + next_agent_state_dict: Dict[Any, np.ndarray] + + @property + def agent_names(self) -> list: + return sorted(self.agent_state_dict.keys()) + + @property + def num_agents(self) -> int: + return len(self.agent_state_dict) + + def split_contents_by_agent(self) -> Dict[Any, ExpElement]: + ret = {} + for agent_name in self.agent_state_dict.keys(): + ret[agent_name] = ExpElement( + tick=self.tick, + state=self.state, + agent_state_dict={agent_name: self.agent_state_dict[agent_name]}, + action_dict={agent_name: self.action_dict[agent_name]}, + reward_dict={agent_name: self.reward_dict[agent_name]}, + terminal_dict={agent_name: self.terminal_dict[agent_name]}, + next_state=self.next_state, + next_agent_state_dict={ + agent_name: self.next_agent_state_dict[agent_name] + } if self.next_agent_state_dict is not None and agent_name in self.next_agent_state_dict else {}, + ) + return ret + + def split_contents_by_trainer(self, agent2trainer: Dict[Any, str]) -> Dict[str, ExpElement]: + """Split the ExpElement's contents by trainer. + + Args: + agent2trainer (Dict[Any, str]): Mapping of agent name and trainer name. + + Returns: + Contents (Dict[str, ExpElement]): A dict that contains the ExpElements of all trainers. The key of this + dict is the trainer name. 
+ """ + ret = collections.defaultdict(lambda: ExpElement( + tick=self.tick, + state=self.state, + agent_state_dict={}, + action_dict={}, + reward_dict={}, + terminal_dict={}, + next_state=self.next_state, + next_agent_state_dict=None if self.next_agent_state_dict is None else {}, + )) + for agent_name, trainer_name in agent2trainer.items(): + if agent_name in self.agent_state_dict: + ret[trainer_name].agent_state_dict[agent_name] = self.agent_state_dict[agent_name] + ret[trainer_name].action_dict[agent_name] = self.action_dict[agent_name] + ret[trainer_name].reward_dict[agent_name] = self.reward_dict[agent_name] + ret[trainer_name].terminal_dict[agent_name] = self.terminal_dict[agent_name] + if self.next_agent_state_dict is not None and agent_name in self.next_agent_state_dict: + ret[trainer_name].next_agent_state_dict[agent_name] = self.next_agent_state_dict[agent_name] + return ret + + +@dataclass +class CacheElement(ExpElement): + event: object + env_action_dict: Dict[Any, np.ndarray] + + def make_exp_element(self) -> ExpElement: + assert len(self.terminal_dict) == len(self.agent_state_dict) == len(self.action_dict) + assert len(self.terminal_dict) == len(self.next_agent_state_dict) == len(self.reward_dict) + + return ExpElement( + tick=self.tick, + state=self.state, + agent_state_dict=self.agent_state_dict, + action_dict=self.action_dict, + reward_dict=self.reward_dict, + terminal_dict=self.terminal_dict, + next_state=self.next_state, + next_agent_state_dict=self.next_agent_state_dict, + ) + + +class AbsEnvSampler(object, metaclass=ABCMeta): + """Simulation data collector and policy evaluator. + + Args: + learn_env (Env): Environment used for training. + test_env (Env): Environment used for testing. + agent_wrapper_cls (Type[AbsAgentWrapper], default=SimpleAgentWrapper): Specific AgentWrapper type. + reward_eval_delay (int, default=None): Number of ticks required after a decision event to evaluate the reward + for the action taken for that event. If it is None, calculate reward immediately after `step()`. + """ + + def __init__( + self, + learn_env: Env, + test_env: Env, + agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper, + reward_eval_delay: int = None, + ) -> None: + self._learn_env = learn_env + self._test_env = test_env + + self._agent_wrapper_cls = agent_wrapper_cls + + self._event = None + self._end_of_episode = True + self._state: Optional[np.ndarray] = None + self._agent_state_dict: Dict[Any, np.ndarray] = {} + + self._trans_cache: List[CacheElement] = [] + self._agent_last_index: Dict[Any, int] = {} # Index of last occurrence of agent in self._trans_cache + self._reward_eval_delay = reward_eval_delay + + self._info = {} + + assert self._reward_eval_delay is None or self._reward_eval_delay >= 0 + + def build( + self, + rl_component_bundle: RLComponentBundle, + ) -> None: + """ + Args: + rl_component_bundle (RLComponentBundle): The RL component bundle of the job. 
+ """ + self._env: Optional[Env] = None + + self._policy_dict = { + policy_name: rl_component_bundle.policy_creator[policy_name]() + for policy_name in rl_component_bundle.policy_names + } + + self._rl_policy_dict: Dict[str, RLPolicy] = { + name: policy for name, policy in self._policy_dict.items() if isinstance(policy, RLPolicy) + } + self._agent2policy = rl_component_bundle.agent2policy + self._agent_wrapper = self._agent_wrapper_cls(self._policy_dict, self._agent2policy) + self._trainable_policies = set(rl_component_bundle.trainable_policy_names) + self._trainable_agents = { + agent_id for agent_id, policy_name in self._agent2policy.items() if policy_name in self._trainable_policies + } + + assert all([policy_name in self._rl_policy_dict for policy_name in self._trainable_policies]), \ + "All trainable policies must be RL policies!" + + def assign_policy_to_device(self, policy_name: str, device: torch.device) -> None: + self._rl_policy_dict[policy_name].to_device(device) + + def _get_global_and_agent_state( + self, event: object, tick: int = None, + ) -> Tuple[Optional[object], Dict[Any, Union[np.ndarray, List[object]]]]: + """Get the global and individual agents' states. + + Args: + event (object): Event. + tick (int, default=None): Current tick. + + Returns: + Global state (Optional[object]) + Dict of agent states (Dict[Any, Union[np.ndarray, List[object]]]). If the policy is a `RLPolicy`, + its state is a Numpy array. Otherwise, its state is a list of objects. + """ + global_state, agent_state_dict = self._get_global_and_agent_state_impl(event, tick) + for agent_name, state in agent_state_dict.items(): + policy_name = self._agent2policy[agent_name] + policy = self._policy_dict[policy_name] + if isinstance(policy, RLPolicy) and not isinstance(state, np.ndarray): + raise ValueError(f"Agent {agent_name} uses a RLPolicy but its state is not a np.ndarray.") + return global_state, agent_state_dict + + @abstractmethod + def _get_global_and_agent_state_impl( + self, event: object, tick: int = None, + ) -> Tuple[Union[None, np.ndarray, List[object]], Dict[Any, Union[np.ndarray, List[object]]]]: + raise NotImplementedError + + @abstractmethod + def _translate_to_env_action( + self, action_dict: Dict[Any, Union[np.ndarray, List[object]]], event: object, + ) -> Dict[Any, object]: + """Translate model-generated actions into an object that can be executed by the env. + + Args: + action_dict (Dict[Any, Union[np.ndarray, List[object]]]): Action for all agents. If the policy is a + `RLPolicy`, its (input) action is a Numpy array. Otherwise, its (input) action is a list of objects. + event (object): Decision event. + + Returns: + A dict that contains env actions for all agents. + """ + raise NotImplementedError + + @abstractmethod + def _get_reward(self, env_action_dict: Dict[Any, object], event: object, tick: int) -> Dict[Any, float]: + """Get rewards according to the env actions. + + Args: + env_action_dict (Dict[Any, object]): Dict that contains env actions for all agents. + event (object): Decision event. + tick (int): Current tick. + + Returns: + A dict that contains rewards for all agents. 
+ """ + raise NotImplementedError + + def _step(self, actions: Optional[list]) -> None: + _, self._event, self._end_of_episode = self._env.step(actions) + self._state, self._agent_state_dict = (None, {}) \ + if self._end_of_episode else self._get_global_and_agent_state(self._event) + + def _calc_reward(self, cache_element: CacheElement) -> None: + cache_element.reward_dict = self._get_reward( + cache_element.env_action_dict, cache_element.event, cache_element.tick, + ) + + def _append_cache_element(self, cache_element: Optional[CacheElement]) -> None: + """`cache_element` == None means we are processing the last element in trans_cache""" + if cache_element is None: + if len(self._trans_cache) > 0: + self._trans_cache[-1].next_state = self._trans_cache[-1].state + + for agent_name, i in self._agent_last_index.items(): + e = self._trans_cache[i] + e.terminal_dict[agent_name] = self._end_of_episode + e.next_agent_state_dict[agent_name] = e.agent_state_dict[agent_name] + else: + self._trans_cache.append(cache_element) + + if len(self._trans_cache) > 0: + self._trans_cache[-1].next_state = cache_element.state + + cur_index = len(self._trans_cache) - 1 + for agent_name in cache_element.agent_names: + if agent_name in self._agent_last_index: + i = self._agent_last_index[agent_name] + self._trans_cache[i].terminal_dict[agent_name] = False + self._trans_cache[i].next_agent_state_dict[agent_name] = cache_element.agent_state_dict[agent_name] + self._agent_last_index[agent_name] = cur_index + + def _reset(self) -> None: + self._env.reset() + self._info.clear() + self._trans_cache.clear() + self._agent_last_index.clear() + self._step(None) + + def _select_trainable_agents(self, original_dict: dict) -> dict: + return { + k: v + for k, v in original_dict.items() + if k in self._trainable_agents + } + + def sample(self, policy_state: Optional[Dict[str, dict]] = None, num_steps: Optional[int] = None) -> dict: + """Sample experiences. + + Args: + policy_state (Dict[str, dict]): Policy state dict. If it is not None, then we need to update all + policies according to the latest policy states, then start the experience collection. + num_steps (Optional[int], default=None): Number of collecting steps. If it is None, interactions with + the environment will continue until the terminal state is reached. + + Returns: + A dict that contains the collected experiences and additional information. 
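_append_cache_element above relies on _agent_last_index to finish each agent's previous transition: when the agent is observed again (or the episode ends), the earlier cache entry receives its terminal flag and next agent state. A toy, dependency-free sketch of that linking:

    # Toy version of the linking: each new observation of an agent completes the agent's
    # previous cache entry with terminal=False and the newly observed state.
    cache = []               # list of {"agent_states": {...}, "next_agent_states": {}, "terminal": {}}
    agent_last_index = {}    # agent -> index of its last occurrence in `cache`

    def append(agent_states):
        cache.append({"agent_states": agent_states, "next_agent_states": {}, "terminal": {}})
        cur = len(cache) - 1
        for agent, state in agent_states.items():
            if agent in agent_last_index:
                prev = cache[agent_last_index[agent]]
                prev["terminal"][agent] = False
                prev["next_agent_states"][agent] = state
            agent_last_index[agent] = cur

    append({"agent_0": [0.0]})
    append({"agent_0": [1.0]})              # completes the first entry for agent_0
    print(cache[0]["next_agent_states"])    # {'agent_0': [1.0]}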
+ """ + # Init the env + self._env = self._learn_env + if self._end_of_episode: + self._reset() + + # Update policy state if necessary + if policy_state is not None: + self.set_policy_state(policy_state) + + # Collect experience + self._agent_wrapper.explore() + steps_to_go = float("inf") if num_steps is None else num_steps + while not self._end_of_episode and steps_to_go > 0: + # Get agent actions and translate them to env actions + action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict) + env_action_dict = self._translate_to_env_action(action_dict, self._event) + + # Store experiences in the cache + cache_element = CacheElement( + tick=self._env.tick, + event=self._event, + state=self._state, + agent_state_dict=self._select_trainable_agents(self._agent_state_dict), + action_dict=self._select_trainable_agents(action_dict), + env_action_dict=self._select_trainable_agents(env_action_dict), + # The following will be generated later + reward_dict={}, + terminal_dict={}, + next_state=None, + next_agent_state_dict={}, + ) + + # Update env and get new states (global & agent) + self._step(list(env_action_dict.values())) + + if self._reward_eval_delay is None: + self._calc_reward(cache_element) + self._post_step(cache_element) + self._append_cache_element(cache_element) + steps_to_go -= 1 + self._append_cache_element(None) + + tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) + experiences: List[ExpElement] = [] + while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound: + cache_element = self._trans_cache.pop(0) + # !: Here the reward calculation method requires the given tick is enough and must be used then. + if self._reward_eval_delay is not None: + self._calc_reward(cache_element) + self._post_step(cache_element) + experiences.append(cache_element.make_exp_element()) + + self._agent_last_index = { + k: v - len(experiences) + for k, v in self._agent_last_index.items() + if v >= len(experiences) + } + + return { + "end_of_episode": self._end_of_episode, + "experiences": [experiences], + "info": [deepcopy(self._info)], # TODO: may have overhead issues. Leave to future work. + } + + def set_policy_state(self, policy_state_dict: Dict[str, dict]) -> None: + """Set policies' states. + + Args: + policy_state_dict (Dict[str, dict]): Double-deck dict with format: {policy_name: policy_state}. 
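A sketch of how sample() might be driven from a training loop; env_sampler is assumed to be an instance of a concrete AbsEnvSampler subclass, and only the result keys shown in the return statement above are used:

    def collect_one_round(env_sampler, policy_state=None, num_steps=128):
        # Roll out up to `num_steps` steps (or to episode end) and unpack the result dict.
        result = env_sampler.sample(policy_state=policy_state, num_steps=num_steps)
        experiences = result["experiences"][0]   # list of ExpElement
        info = result["info"][0]
        return experiences, info, result["end_of_episode"]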
+ """ + self._agent_wrapper.set_policy_state(policy_state_dict) + + def load_policy_state(self, path: str) -> List[str]: + file_list = os.listdir(path) + policy_state_dict = {} + loaded = [] + for file_name in file_list: + if "non_policy" in file_name or not file_name.endswith(f"_policy.{FILE_SUFFIX}"): # TODO: remove hardcode + continue + policy_name, policy_state = torch.load(os.path.join(path, file_name)) + policy_state_dict[policy_name] = policy_state + loaded.append(policy_name) + self.set_policy_state(policy_state_dict) + + return loaded + + def eval(self, policy_state: Dict[str, dict] = None) -> dict: + self._env = self._test_env + self._reset() + if policy_state is not None: + self.set_policy_state(policy_state) + + self._agent_wrapper.exploit() + while not self._end_of_episode: + action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict) + env_action_dict = self._translate_to_env_action(action_dict, self._event) + + # Store experiences in the cache + cache_element = CacheElement( + tick=self._env.tick, + event=self._event, + state=self._state, + agent_state_dict=self._select_trainable_agents(self._agent_state_dict), + action_dict=self._select_trainable_agents(action_dict), + env_action_dict=self._select_trainable_agents(env_action_dict), + # The following will be generated later + reward_dict={}, + terminal_dict={}, + next_state=None, + next_agent_state_dict={}, + ) + + # Update env and get new states (global & agent) + self._step(list(env_action_dict.values())) + + if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()? + self._calc_reward(cache_element) + self._post_eval_step(cache_element) + + self._append_cache_element(cache_element) + self._append_cache_element(None) + + tick_bound = self._env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay) + while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound: + cache_element = self._trans_cache.pop(0) + if self._reward_eval_delay is not None: + self._calc_reward(cache_element) + self._post_eval_step(cache_element) + + return {"info": [self._info]} + + @abstractmethod + def _post_step(self, cache_element: CacheElement) -> None: + raise NotImplementedError + + @abstractmethod + def _post_eval_step(self, cache_element: CacheElement) -> None: + raise NotImplementedError + + def post_collect(self, info_list: list, ep: int) -> None: + """Routines to be invoked at the end of training episodes""" + pass + + def post_evaluate(self, info_list: list, ep: int) -> None: + """Routines to be invoked at the end of evaluation episodes""" + pass diff --git a/maro/rl/rollout/worker.py b/maro/rl/rollout/worker.py new file mode 100644 index 000000000..db1f3fea8 --- /dev/null +++ b/maro/rl/rollout/worker.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import typing + +from maro.rl.distributed import AbsWorker +from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes +from maro.utils import LoggerV2 + +if typing.TYPE_CHECKING: + from maro.rl.rl_component.rl_component_bundle import RLComponentBundle + + +class RolloutWorker(AbsWorker): + """Worker that hosts an environment simulator and executes roll-out on demand for sampling and evaluation purposes. + + Args: + idx (int): Integer identifier for the worker. It is used to generate an internal ID, "worker.{idx}", + so that the parallel roll-out controller can keep track of its connection status. 
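load_policy_state above looks for files ending in "_policy.<FILE_SUFFIX>" (the suffix constant is imported elsewhere in this module) whose contents are a (policy_name, policy_state) tuple. A sketch of writing a compatible checkpoint with torch.save; the suffix below is a placeholder standing in for FILE_SUFFIX:

    import os

    import torch

    def save_policy_checkpoint(path: str, policy_name: str, policy_state: dict, suffix: str = "ckpt") -> None:
        # `suffix` stands in for the module's FILE_SUFFIX constant, whose real value is defined elsewhere.
        os.makedirs(path, exist_ok=True)
        torch.save((policy_name, policy_state), os.path.join(path, f"{policy_name}_policy.{suffix}"))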
+ rl_component_bundle (RLComponentBundle): The RL component bundle of the job. + producer_host (str): IP address of the parallel task controller host to connect to. + producer_port (int, default=20000): Port of the parallel task controller host to connect to. + logger (LoggerV2, default=None): The logger of the workflow. + """ + + def __init__( + self, + idx: int, + rl_component_bundle: RLComponentBundle, + producer_host: str, + producer_port: int = 20000, + logger: LoggerV2 = None, + ) -> None: + super(RolloutWorker, self).__init__( + idx=idx, producer_host=producer_host, producer_port=producer_port, logger=logger, + ) + self._env_sampler = rl_component_bundle.env_sampler + + def _compute(self, msg: list) -> None: + """Perform a full or partial episode of roll-out for sampling or evaluation. + + Args: + msg (list): Multi-part message containing roll-out specifications and parameters. + """ + if msg[-1] == b"EXIT": + self._logger.info("Exiting event loop...") + self.stop() + else: + req = bytes_to_pyobj(msg[-1]) + assert isinstance(req, dict) + assert req["type"] in {"sample", "eval", "set_policy_state"} + if req["type"] == "sample": + result = self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"]) + elif req["type"] == "eval": + result = self._env_sampler.eval(policy_state=req["policy_state"]) + else: + self._env_sampler.set_policy_state(policy_state_dict=req["policy_state"]) + result = True + + self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]})) diff --git a/maro/rl/scheduling/__init__.py b/maro/rl/scheduling/__init__.py deleted file mode 100644 index 1b5c46b3b..000000000 --- a/maro/rl/scheduling/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from .scheduler import Scheduler -from .simple_parameter_scheduler import LinearParameterScheduler, TwoPhaseLinearParameterScheduler - -__all__ = [ - "Scheduler", - "LinearParameterScheduler", - "TwoPhaseLinearParameterScheduler" -] diff --git a/maro/rl/scheduling/scheduler.py b/maro/rl/scheduling/scheduler.py deleted file mode 100644 index 75d22c702..000000000 --- a/maro/rl/scheduling/scheduler.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -class Scheduler(object): - """Scheduler that generates new parameters each iteration. - - Args: - max_iter (int): Maximum number of iterations. If -1, using the scheduler in a for-loop - will result in an infinite loop unless the ``check_for_stopping`` method is implemented. - """ - - def __init__(self, max_iter: int = -1): - if max_iter <= 0 and max_iter != -1: - raise ValueError("max_iter must be a positive integer or -1.") - self._max_iter = max_iter - self._iter_index = -1 - - def __iter__(self): - return self - - def __next__(self): - self._iter_index += 1 - if self._iter_index == self._max_iter or self.check_for_stopping(): - raise StopIteration - - return self.next_params() - - def next_params(self): - pass - - def check_for_stopping(self) -> bool: - return False - - @property - def iter(self): - return self._iter_index diff --git a/maro/rl/scheduling/simple_parameter_scheduler.py b/maro/rl/scheduling/simple_parameter_scheduler.py deleted file mode 100644 index 0e4e04550..000000000 --- a/maro/rl/scheduling/simple_parameter_scheduler.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
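_compute above accepts either a literal b"EXIT" frame or a serialized dict whose "type" is "sample", "eval" or "set_policy_state", together with the fields read for that type and an "index" that is echoed back in the reply. A sketch of building such a request on the controller side (the ZMQ transport itself is handled by the existing proxy classes and is not shown; the payload values are hypothetical):

    from maro.rl.utils.common import pyobj_to_bytes  # same helper used by RolloutWorker

    # Request asking a worker to roll out 64 steps with the latest policy states.
    sample_request = pyobj_to_bytes({
        "type": "sample",
        "policy_state": {"policy_a": {}},  # {policy_name: policy_state}
        "num_steps": 64,
        "index": 0,                        # echoed back as "index" in the worker's reply
    })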
- -from typing import List, Union - -import numpy as np - -from .scheduler import Scheduler - - -class LinearParameterScheduler(Scheduler): - """Static exploration parameter generator based on a linear schedule. - - Args: - max_iter (int): Maximum number of iterations. - parameter_names (List[str]): List of exploration parameter names. - start (Union[float, list, tuple, np.ndarray]): Exploration parameter values for the first episode. - These values must correspond to ``parameter_names``. - end (Union[float, list, tuple, np.ndarray]): Exploration parameter values rate for the last episode. - These values must correspond to ``parameter_names``. - """ - def __init__( - self, - max_iter: int, - parameter_names: List[str], - start: Union[float, list, tuple, np.ndarray], - end: Union[float, list, tuple, np.ndarray] - ): - super().__init__(max_iter) - self._parameter_names = parameter_names - if isinstance(start, float): - self._current_values = start * np.ones(len(self._parameter_names)) - elif isinstance(start, (list, tuple)): - self._current_values = np.asarray(start) - else: - self._current_values = start - - if isinstance(end, float): - end = end * np.ones(len(self._parameter_names)) - elif isinstance(end, (list, tuple)): - end = np.asarray(end) - - self._delta = (end - self._current_values) / (self._max_iter - 1) - - def next_params(self): - current_values = self._current_values.copy() - self._current_values += self._delta - return dict(zip(self._parameter_names, current_values)) - - -class TwoPhaseLinearParameterScheduler(Scheduler): - """Exploration parameter generator based on two linear schedules joined together. - - Args: - max_iter (int): Maximum number of iterations. - parameter_names (List[str]): List of exploration parameter names. - split (float): The point where the switch from the first linear schedule to the second occurs. - start (Union[float, list, tuple, np.ndarray]): Exploration parameter values for the first episode. - These values must correspond to ``parameter_names``. - mid (Union[float, list, tuple, np.ndarray]): Exploration parameter values where the switch from the - first linear schedule to the second occurs. In other words, this is the exploration rate where the first - linear schedule ends and the second begins. These values must correspond to ``parameter_names``. - end (Union[float, list, tuple, np.ndarray]): Exploration parameter values for the last episode. - These values must correspond to ``parameter_names``. - - Returns: - An iterator over the series of exploration rates from episode 0 to ``max_iter`` - 1. 
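For the removed LinearParameterScheduler above, the per-iteration increment is simply (end - start) / (max_iter - 1); a quick numeric check:

    start, end, max_iter = 0.9, 0.0, 10
    delta = (end - start) / (max_iter - 1)                  # -0.1 per iteration
    schedule = [start + i * delta for i in range(max_iter)]
    print(round(schedule[0], 6), round(schedule[-1], 6))    # 0.9 0.0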
- """ - def __init__( - self, - max_iter: int, - parameter_names: List[str], - split: float, - start: Union[float, list, tuple, np.ndarray], - mid: Union[float, list, tuple, np.ndarray], - end: Union[float, list, tuple, np.ndarray] - ): - if split < 0 or split > 1.0: - raise ValueError("split must be a float between 0 and 1.") - super().__init__(max_iter) - self._parameter_names = parameter_names - self._split = int(self._max_iter * split) - if isinstance(start, float): - self._current_values = start * np.ones(len(self._parameter_names)) - elif isinstance(start, (list, tuple)): - self._current_values = np.asarray(start) - else: - self._current_values = start - - if isinstance(mid, float): - mid = mid * np.ones(len(self._parameter_names)) - elif isinstance(mid, (list, tuple)): - mid = np.asarray(mid) - - if isinstance(end, float): - end = end * np.ones(len(self._parameter_names)) - elif isinstance(end, (list, tuple)): - end = np.asarray(end) - - self._delta_1 = (mid - self._current_values) / self._split - self._delta_2 = (end - mid) / (max_iter - self._split - 1) - - def next_params(self): - current_values = self._current_values.copy() - self._current_values += self._delta_1 if self._iter_index < self._split else self._delta_2 - return dict(zip(self._parameter_names, current_values)) diff --git a/maro/rl/storage/__init__.py b/maro/rl/storage/__init__.py deleted file mode 100644 index 4ea19a059..000000000 --- a/maro/rl/storage/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from .abs_store import AbsStore -from .simple_store import OverwriteType, SimpleStore - -__all__ = ["AbsStore", "OverwriteType", "SimpleStore"] diff --git a/maro/rl/storage/abs_store.py b/maro/rl/storage/abs_store.py deleted file mode 100644 index 0c32fb1e4..000000000 --- a/maro/rl/storage/abs_store.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from abc import ABC, abstractmethod -from typing import Callable, Sequence - - -class AbsStore(ABC): - """A data store abstraction that supports get, put, update and sample operations.""" - def __init__(self): - pass - - @abstractmethod - def get(self, indexes: Sequence): - """Get contents. - - Args: - indexes: A sequence of indexes to retrieve contents at. - Returns: - Retrieved contents. - """ - pass - - def put(self, contents: Sequence): - """Put new contents. - - Args: - contents (Sequence): Contents to be added to the store. - Returns: - The indexes where the newly added entries reside in the store. - """ - pass - - @abstractmethod - def update(self, indexes: Sequence, contents: Sequence): - """Update the store contents at given positions. - - Args: - indexes (Sequence): Positions where updates are to be made. - contents (Sequence): Item list, which has the same length as indexes. - Returns: - The indexes where store contents are updated. - """ - pass - - def filter(self, filters: Sequence[Callable]): - """Multi-filter method. - - The input to one filter is the output from the previous filter. - - Args: - filters (Sequence[Callable]): Filter list, each item is a lambda function, - e.g., [lambda d: d['a'] == 1 and d['b'] == 1]. - Returns: - Filtered indexes and corresponding objects. - """ - pass - - @abstractmethod - def sample(self, size: int, weights: Sequence, replace: bool = True): - """Obtain a random sample from the experience pool. - - Args: - size (int): Sample sizes for each round of sampling in the chain. 
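The two-phase variant above splits max_iter at int(max_iter * split) and applies two slopes, (mid - start) / split_index before the split and (end - mid) / (max_iter - split_index - 1) after it. A small numeric sketch that roughly mirrors the arithmetic:

    max_iter, split, start, mid, end = 10, 0.5, 0.9, 0.5, 0.0
    split_index = int(max_iter * split)                    # 5
    delta_1 = (mid - start) / split_index                  # -0.08 per iteration before the split
    delta_2 = (end - mid) / (max_iter - split_index - 1)   # -0.125 per iteration after the split

    value, schedule = start, []
    for i in range(max_iter):
        schedule.append(value)
        value += delta_1 if i < split_index else delta_2
    print([round(v, 3) for v in schedule])
    # [0.9, 0.82, 0.74, 0.66, 0.58, 0.5, 0.375, 0.25, 0.125, 0.0]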
If this is a single integer, it is - used as the sample size for all samplers in the chain. - weights (Sequence): A sequence of sampling weights. - replace (bool): If True, sampling is performed with replacement. Defaults to True. - Returns: - A random sample from the experience pool. - """ - pass diff --git a/maro/rl/storage/simple_store.py b/maro/rl/storage/simple_store.py deleted file mode 100644 index 9632cf9d6..000000000 --- a/maro/rl/storage/simple_store.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from enum import Enum -from typing import Callable, Dict, List, Tuple, Union - -import numpy as np - -from maro.utils import clone -from maro.utils.exception.rl_toolkit_exception import StoreMisalignment - -from .abs_store import AbsStore - - -class OverwriteType(Enum): - ROLLING = "rolling" - RANDOM = "random" - - -class SimpleStore(AbsStore): - """ - An implementation of ``AbsStore`` for experience storage in RL. - - This implementation uses a dictionary of lists as the internal data structure. The objects for each key - are stored in a list. To be useful for experience storage in RL, uniformity checks are performed during - put operations to ensure that the list lengths stay the same for all keys at all times. Both unlimited - and limited storage are supported. - - Args: - keys (list): List of keys identifying each column. - capacity (int): If negative, the store is of unlimited capacity. Defaults to -1. - overwrite_type (OverwriteType): If storage capacity is bounded, this specifies how existing entries - are overwritten when the capacity is exceeded. Two types of overwrite behavior are supported: - - Rolling, where overwrite occurs sequentially with wrap-around. - - Random, where overwrite occurs randomly among filled positions. - Alternatively, the user may also specify overwrite positions (see ``put``). - """ - def __init__(self, keys: list, capacity: int = -1, overwrite_type: OverwriteType = None): - super().__init__() - self._keys = keys - self._capacity = capacity - self._overwrite_type = overwrite_type - self._store = {key: [] if self._capacity < 0 else [None] * self._capacity for key in keys} - self._size = 0 - self._iter_index = 0 - - def __len__(self): - return self._size - - def __iter__(self): - return self - - def __next__(self): - if self._iter_index >= self._size: - self._iter_index = 0 - raise StopIteration - index = self._iter_index - self._iter_index += 1 - return {k: lst[index] for k, lst in self._store.items()} - - def __getitem__(self, index: int): - return {k: lst[index] for k, lst in self._store.items()} - - @property - def keys(self): - return self._keys - - @property - def capacity(self): - """Store capacity. - - If negative, the store grows without bound. Otherwise, the number of items in the store will not exceed - this capacity. - """ - return self._capacity - - @property - def overwrite_type(self): - """An ``OverwriteType`` member indicating the overwrite behavior when the store capacity is exceeded.""" - return self._overwrite_type - - def get(self, indexes: [int]) -> dict: - return {k: [self._store[k][i] for i in indexes] for k in self._store} - - def put(self, contents: Dict[str, List], overwrite_indexes: list = None) -> List[int]: - """Put new contents in the store. - - Args: - contents (dict): Dictionary of items to add to the store. If the store is not empty, this must have the - same keys as the store itself. Otherwise an ``StoreMisalignment`` will be raised. 
- overwrite_indexes (list, optional): Indexes where the contents are to be overwritten. This is only - used when the store has a fixed capacity and putting ``contents`` in the store would exceed this - capacity. If this is None and overwriting is necessary, rolling or random overwriting will be done - according to the ``overwrite`` property. Defaults to None. - Returns: - The indexes where the newly added entries reside in the store. - """ - if len(self._store) > 0 and list(contents.keys()) != self._keys: - raise StoreMisalignment(f"expected keys {self._keys}, got {list(contents.keys())}") - self.validate(contents) - added = contents[next(iter(contents))] - added_size = len(added) if isinstance(added, list) else 1 - if self._capacity < 0: - for key, val in contents.items(): - self._store[key].extend(val) - self._size += added_size - return list(range(self._size - added_size, self._size)) - else: - write_indexes = self._get_update_indexes(added_size, overwrite_indexes=overwrite_indexes) - self.update(write_indexes, contents) - self._size = min(self._capacity, self._size + added_size) - return write_indexes - - def update(self, indexes: list, contents: Dict[str, List]): - """ - Update contents at given positions. - - Args: - indexes (list): Positions where updates are to be made. - contents (dict): Contents to write to the internal store at given positions. It is subject to - uniformity checks to ensure that all values have the same length. - - Returns: - The indexes where store contents are updated. - """ - self.validate(contents) - for key, val in contents.items(): - for index, value in zip(indexes, val): - self._store[key][index] = value - - return indexes - - def apply_multi_filters(self, filters: List[Callable]): - """Multi-filter method. - - The input to one filter is the output from its predecessor in the sequence. - - Args: - filters (List[Callable]): Filter list, each item is a lambda function, - e.g., [lambda d: d['a'] == 1 and d['b'] == 1]. - Returns: - Filtered indexes and corresponding objects. - """ - indexes = range(self._size) - for f in filters: - indexes = [i for i in indexes if f(self[i])] - - return indexes, self.get(indexes) - - def apply_multi_samplers(self, samplers: list, replace: bool = True) -> Tuple: - """Multi-samplers method. - - This implements chained sampling where the input to one sampler is the output from its predecessor in - the sequence. - - Args: - samplers (list): A sequence of weight functions for computing the sampling weights of the items - in the store, - e.g., [lambda d: d['a'], lambda d: d['b']]. - replace (bool): If True, sampling will be performed with replacement. - Returns: - Sampled indexes and corresponding objects. - """ - indexes = range(self._size) - for weight_fn, sample_size in samplers: - weights = np.asarray([weight_fn(self[i]) for i in indexes]) - indexes = np.random.choice(indexes, size=sample_size, replace=replace, p=weights / np.sum(weights)) - - return indexes, self.get(indexes) - - def sample(self, size, weights: Union[list, np.ndarray] = None, replace: bool = True): - """ - Obtain a random sample from the experience pool. - - Args: - size (int): Sample sizes for each round of sampling in the chain. If this is a single integer, it is - used as the sample size for all samplers in the chain. - weights (Union[list, np.ndarray]): Sampling weights. - replace (bool): If True, sampling is performed with replacement. Defaults to True. - Returns: - Sampled indexes and the corresponding objects, - e.g., [1, 2, 3], ['a', 'b', 'c']. 
- """ - if weights is not None: - weights = np.asarray(weights) - weights = weights / np.sum(weights) - indexes = np.random.choice(self._size, size=size, replace=replace, p=weights) - return indexes, self.get(indexes) - - def sample_by_key(self, key, size: int, replace: bool = True): - """ - Obtain a random sample from the store using one of the columns as sampling weights. - - Args: - key: The column whose values are to be used as sampling weights. - size (int): Sample size. - replace (bool): If True, sampling is performed with replacement. - Returns: - Sampled indexes and the corresponding objects. - """ - weights = np.asarray(self._store[key][:self._size] if self._size < self._capacity else self._store[key]) - indexes = np.random.choice(self._size, size=size, replace=replace, p=weights / np.sum(weights)) - return indexes, self.get(indexes) - - def sample_by_keys(self, keys: list, sizes: list, replace: bool = True): - """ - Obtain a random sample from the store by chained sampling using multiple columns as sampling weights. - - Args: - keys (list): The column whose values are to be used as sampling weights. - sizes (list): Sample size. - replace (bool): If True, sampling is performed with replacement. - Returns: - Sampled indexes and the corresponding objects. - """ - if len(keys) != len(sizes): - raise ValueError(f"expected sizes of length {len(keys)}, got {len(sizes)}") - - indexes = range(self._size) - for key, size in zip(keys, sizes): - weights = np.asarray([self._store[key][i] for i in indexes]) - indexes = np.random.choice(indexes, size=size, replace=replace, p=weights / np.sum(weights)) - - return indexes, self.get(indexes) - - def clear(self): - """Empty the store.""" - self._store = {key: [] if self._capacity < 0 else [None] * self._capacity for key in self._keys} - self._size = 0 - self._iter_index = 0 - - def dumps(self): - """Return a deep copy of store contents.""" - return clone(dict(self._store)) - - def get_by_key(self, key): - """Get the contents of the store corresponding to ``key``.""" - return self._store[key] - - def _get_update_indexes(self, added_size: int, overwrite_indexes=None): - if added_size > self._capacity: - raise ValueError("size of added items should not exceed the store capacity.") - - num_overwrites = self._size + added_size - self._capacity - if num_overwrites < 0: - return list(range(self._size, self._size + added_size)) - - if overwrite_indexes is not None: - write_indexes = list(range(self._size, self._capacity)) + list(overwrite_indexes) - else: - # follow the overwrite rule set at init - if self._overwrite_type == OverwriteType.ROLLING: - # using the negative index convention for convenience - start_index = self._size - self._capacity - write_indexes = list(range(start_index, start_index + added_size)) - else: - random_indexes = np.random.choice(self._size, size=num_overwrites, replace=False) - write_indexes = list(range(self._size, self._capacity)) + list(random_indexes) - - return write_indexes - - @staticmethod - def validate(contents: Dict[str, List]): - # Ensure that all values are lists of the same length. 
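The rolling branch of _get_update_indexes above exploits Python's negative indexing: with start_index = size - capacity, the generated indexes first land on the still-empty tail slots and then wrap around to the oldest entries. A quick illustration:

    capacity, size, added = 10, 8, 4            # 2 items must be overwritten
    start_index = size - capacity               # -2
    write_indexes = list(range(start_index, start_index + added))
    print(write_indexes)                        # [-2, -1, 0, 1]
    # In a list preallocated to `capacity`, -2 and -1 address the still-empty tail slots,
    # while 0 and 1 wrap around to the oldest entries.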
- if any(not isinstance(val, list) for val in contents.values()): - raise TypeError("All values must be of type 'list'") - - reference_val = contents[list(contents.keys())[0]] - if any(len(val) != len(reference_val) for val in contents.values()): - raise StoreMisalignment("values of contents should consist of lists of the same length") diff --git a/maro/rl/training/__init__.py b/maro/rl/training/__init__.py index 4bd6269b3..d08313e23 100644 --- a/maro/rl/training/__init__.py +++ b/maro/rl/training/__init__.py @@ -1,9 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .actor import Actor -from .actor_proxy import ActorProxy -from .learner import AbsLearner, OffPolicyLearner, OnPolicyLearner -from .trajectory import Trajectory +from .proxy import TrainingProxy +from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory +from .train_ops import AbsTrainOps, remote, RemoteOps +from .trainer import AbsTrainer, MultiAgentTrainer, SingleAgentTrainer, TrainerParams +from .training_manager import TrainingManager +from .worker import TrainOpsWorker -__all__ = ["AbsLearner", "Actor", "ActorProxy", "OffPolicyLearner", "OnPolicyLearner", "Trajectory"] +__all__ = [ + "TrainingProxy", + "FIFOMultiReplayMemory", "FIFOReplayMemory", "RandomMultiReplayMemory", "RandomReplayMemory", + "AbsTrainOps", "RemoteOps", "remote", + "AbsTrainer", "MultiAgentTrainer", "SingleAgentTrainer", "TrainerParams", + "TrainingManager", + "TrainOpsWorker", +] diff --git a/maro/rl/training/actor.py b/maro/rl/training/actor.py deleted file mode 100644 index d05cdd554..000000000 --- a/maro/rl/training/actor.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import sys -from os import getcwd -from typing import Union - -from maro.communication import Message, Proxy -from maro.rl.agent import AbsAgent, MultiAgentWrapper -from maro.simulator import Env -from maro.utils import Logger - -from .message_enums import MessageTag, PayloadKey - - -class Actor(object): - """Actor class that performs roll-out tasks. - - Args: - env (Env): An environment instance. - agent (Union[AbsAgent, MultiAgentWrapper]): Agent that interacts with the environment. - mode (str): One of "local" and "distributed". Defaults to "local". - """ - def __init__( - self, - env: Env, - agent: Union[AbsAgent, MultiAgentWrapper], - trajectory_cls, - trajectory_kwargs: dict = None - ): - super().__init__() - self.env = env - self.agent = agent - if trajectory_kwargs is None: - trajectory_kwargs = {} - self.trajectory = trajectory_cls(self.env, **trajectory_kwargs) - - def roll_out(self, index: int, training: bool = True, model_by_agent: dict = None, exploration_params=None): - """Perform one episode of roll-out. - Args: - index (int): Externally designated index to identify the roll-out round. - training (bool): If true, the roll-out is for training purposes, which usually means - some kind of training data, e.g., experiences, needs to be collected. Defaults to True. - model_by_agent (dict): Models to use for inference. Defaults to None. - exploration_params: Exploration parameters to use for the current roll-out. Defaults to None. - Returns: - Data collected during the episode. 
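The reworked maro.rl.training package above replaces the old actor/learner pair; its public names (taken directly from the new __all__) are imported like this:

    from maro.rl.training import (
        AbsTrainOps,
        FIFOReplayMemory,
        RandomReplayMemory,
        RemoteOps,
        SingleAgentTrainer,
        TrainerParams,
        TrainingManager,
    )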
- """ - self.env.reset() - self.trajectory.reset() - if model_by_agent: - self.agent.load_model(model_by_agent) - if exploration_params: - self.agent.set_exploration_params(exploration_params) - - _, event, is_done = self.env.step(None) - while not is_done: - state_by_agent = self.trajectory.get_state(event) - action_by_agent = self.agent.choose_action(state_by_agent) - env_action = self.trajectory.get_action(action_by_agent, event) - if len(env_action) == 1: - env_action = list(env_action.values())[0] - _, next_event, is_done = self.env.step(env_action) - reward = self.trajectory.get_reward() - self.trajectory.on_env_feedback( - event, state_by_agent, action_by_agent, reward if reward is not None else self.env.metrics - ) - event = next_event - - return self.env.metrics, self.trajectory.on_finish() if training else None - - def as_worker(self, group: str, proxy_options=None, log_dir: str = getcwd()): - """Executes an event loop where roll-outs are performed on demand from a remote learner. - - Args: - group (str): Identifier of the group to which the actor belongs. It must be the same group name - assigned to the learner (and decision clients, if any). - proxy_options (dict): Keyword parameters for the internal ``Proxy`` instance. See ``Proxy`` class - for details. Defaults to None. - """ - if proxy_options is None: - proxy_options = {} - proxy = Proxy(group, "actor", {"learner": 1}, **proxy_options) - logger = Logger(proxy.name, dump_folder=log_dir) - for msg in proxy.receive(): - if msg.tag == MessageTag.EXIT: - logger.info("Exiting...") - proxy.close() - sys.exit(0) - elif msg.tag == MessageTag.ROLLOUT: - ep = msg.payload[PayloadKey.ROLLOUT_INDEX] - logger.info(f"Rolling out ({ep})...") - metrics, rollout_data = self.roll_out( - ep, - training=msg.payload[PayloadKey.TRAINING], - model_by_agent=msg.payload[PayloadKey.MODEL], - exploration_params=msg.payload[PayloadKey.EXPLORATION_PARAMS] - ) - if rollout_data is None: - logger.info(f"Roll-out {ep} aborted") - else: - logger.info(f"Roll-out {ep} finished") - rollout_finish_msg = Message( - MessageTag.FINISHED, - proxy.name, - proxy.peers_name["learner"][0], - payload={ - PayloadKey.ROLLOUT_INDEX: ep, - PayloadKey.METRICS: metrics, - PayloadKey.DETAILS: rollout_data - } - ) - proxy.isend(rollout_finish_msg) - self.env.reset() diff --git a/maro/rl/training/actor_proxy.py b/maro/rl/training/actor_proxy.py deleted file mode 100644 index 3571d69c4..000000000 --- a/maro/rl/training/actor_proxy.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from os import getcwd -from typing import List - -from maro.communication import Message, Proxy, RegisterTable, SessionType -from maro.utils import Logger - -from .message_enums import MessageTag, PayloadKey - - -class ActorProxy(object): - """Actor proxy that manages a set of remote actors. - - Args: - group_name (str): Identifier of the group to which the actor belongs. It must be the same group name - assigned to the actors (and roll-out clients, if any). - num_actors (int): Expected number of actors in the group identified by ``group_name``. - update_trigger (str): Number or percentage of ``MessageTag.FINISHED`` messages required to trigger - learner updates, i.e., model training. - proxy_options (dict): Keyword parameters for the internal ``Proxy`` instance. See ``Proxy`` class - for details. Defaults to None. 
- """ - def __init__( - self, - group_name: str, - num_actors: int, - update_trigger: str = None, - proxy_options: dict = None, - log_dir: str = getcwd() - ): - self.agent = None - peers = {"actor": num_actors} - if proxy_options is None: - proxy_options = {} - self._proxy = Proxy(group_name, "learner", peers, **proxy_options) - self._actors = self._proxy.peers_name["actor"] # remote actor ID's - self._registry_table = RegisterTable(self._proxy.peers_name) - if update_trigger is None: - update_trigger = len(self._actors) - self._registry_table.register_event_handler( - f"actor:{MessageTag.FINISHED.value}:{update_trigger}", self._on_rollout_finish - ) - self.logger = Logger("ACTOR_PROXY", dump_folder=log_dir) - - def roll_out(self, index: int, training: bool = True, model_by_agent: dict = None, exploration_params=None): - """Collect roll-out data from remote actors. - - Args: - index (int): Index of roll-out requests. - training (bool): If true, the roll-out request is for training purposes. - model_by_agent (dict): Models to be broadcast to remote actors for inference. Defaults to None. - exploration_params: Exploration parameters to be used by the remote roll-out actors. Defaults to None. - """ - payload = { - PayloadKey.ROLLOUT_INDEX: index, - PayloadKey.TRAINING: training, - PayloadKey.MODEL: model_by_agent, - PayloadKey.EXPLORATION_PARAMS: exploration_params - } - self._proxy.iscatter(MessageTag.ROLLOUT, SessionType.TASK, [(actor, payload) for actor in self._actors]) - self.logger.info(f"Sent roll-out requests to {self._actors} for ep-{index}") - - # Receive roll-out results from remote actors - for msg in self._proxy.receive(): - if msg.payload[PayloadKey.ROLLOUT_INDEX] != index: - self.logger.info( - f"Ignore a message of type {msg.tag} with ep {msg.payload[PayloadKey.ROLLOUT_INDEX]} " - f"(expected {index} or greater)" - ) - continue - if msg.tag == MessageTag.FINISHED: - # If enough update messages have been received, call update() and break out of the loop to start - # the next episode. - result = self._registry_table.push(msg) - if result: - env_metrics, details = result[0] - break - - return env_metrics, details - - def _on_rollout_finish(self, messages: List[Message]): - metrics = {msg.source: msg.payload[PayloadKey.METRICS] for msg in messages} - details = {msg.source: msg.payload[PayloadKey.DETAILS] for msg in messages} - return metrics, details - - def terminate(self): - """Tell the remote actors to exit.""" - self._proxy.ibroadcast( - component_type="actor", tag=MessageTag.EXIT, session_type=SessionType.NOTIFICATION - ) - self.logger.info("Exiting...") - self._proxy.close() diff --git a/maro/rl/training/algorithms/__init__.py b/maro/rl/training/algorithms/__init__.py new file mode 100644 index 000000000..3bcf80c99 --- /dev/null +++ b/maro/rl/training/algorithms/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from .ac import ActorCriticParams, ActorCriticTrainer +from .ddpg import DDPGParams, DDPGTrainer +from .dqn import DQNParams, DQNTrainer +from .maddpg import DiscreteMADDPGParams, DiscreteMADDPGTrainer +from .ppo import PPOParams, PPOTrainer, DiscretePPOWithEntropyTrainer +from .sac import SoftActorCriticParams, SoftActorCriticTrainer + +__all__ = [ + "ActorCriticTrainer", "ActorCriticParams", + "DDPGTrainer", "DDPGParams", + "DQNTrainer", "DQNParams", + "DiscreteMADDPGTrainer", "DiscreteMADDPGParams", + "PPOParams", "PPOTrainer", "DiscretePPOWithEntropyTrainer", + "SoftActorCriticParams", "SoftActorCriticTrainer", +] diff --git a/maro/rl/training/algorithms/ac.py b/maro/rl/training/algorithms/ac.py new file mode 100644 index 000000000..2f9d576e2 --- /dev/null +++ b/maro/rl/training/algorithms/ac.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass +from typing import Dict + +from maro.rl.training.algorithms.base import ACBasedParams, ACBasedTrainer + + +@dataclass +class ActorCriticParams(ACBasedParams): + """Identical to `ACBasedParams`. Please refer to the doc string of `ACBasedParams` + for detailed information. + """ + + def extract_ops_params(self) -> Dict[str, object]: + return { + "get_v_critic_net_func": self.get_v_critic_net_func, + "reward_discount": self.reward_discount, + "critic_loss_cls": self.critic_loss_cls, + "lam": self.lam, + "min_logp": self.min_logp, + "is_discrete_action": self.is_discrete_action, + } + + def __post_init__(self) -> None: + assert self.get_v_critic_net_func is not None + + +class ActorCriticTrainer(ACBasedTrainer): + """Actor-Critic algorithm with separate policy and value models. + + Reference: + https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/vpg + """ + + def __init__(self, name: str, params: ActorCriticParams) -> None: + super(ActorCriticTrainer, self).__init__(name, params) diff --git a/maro/rl/training/algorithms/base/__init__.py b/maro/rl/training/algorithms/base/__init__.py new file mode 100644 index 000000000..857601d2a --- /dev/null +++ b/maro/rl/training/algorithms/base/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .ac_ppo_base import ACBasedOps, ACBasedParams, ACBasedTrainer + +__all__ = ["ACBasedOps", "ACBasedParams", "ACBasedTrainer"] diff --git a/maro/rl/training/algorithms/base/ac_ppo_base.py b/maro/rl/training/algorithms/base/ac_ppo_base.py new file mode 100644 index 000000000..5c93e9aec --- /dev/null +++ b/maro/rl/training/algorithms/base/ac_ppo_base.py @@ -0,0 +1,301 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABCMeta +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple + +import numpy as np +import torch + +from maro.rl.model import VNet +from maro.rl.policy import ContinuousRLPolicy, DiscretePolicyGradient, RLPolicy +from maro.rl.training import AbsTrainOps, FIFOReplayMemory, remote, RemoteOps, SingleAgentTrainer, TrainerParams +from maro.rl.utils import discount_cumsum, get_torch_device, ndarray_to_tensor, TransitionBatch + + +@dataclass +class ACBasedParams(TrainerParams, metaclass=ABCMeta): + """ + Parameter bundle for Actor-Critic based algorithms (Actor-Critic & PPO) + + get_v_critic_net_func (Callable[[], VNet]): Function to get V critic net. + grad_iters (int, default=1): Number of iterations to calculate gradients. 
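The trainer/parameter pairs exported above are the public entry points for the built-in algorithms; note that ActorCriticParams.__post_init__ requires a get_v_critic_net_func. An import sketch (constructing the params needs a scenario-specific VNet factory, which is omitted here):

    # Names taken from maro/rl/training/algorithms/__init__.py above.
    from maro.rl.training.algorithms import (
        ActorCriticParams,
        ActorCriticTrainer,
        DQNParams,
        DQNTrainer,
        PPOParams,
        PPOTrainer,
    )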
+ critic_loss_cls (Callable, default=None): Critic loss function. If it is None, use MSE. + lam (float, default=0.9): Lambda value for generalized advantage estimation (TD-Lambda). + min_logp (float, default=None): Lower bound for clamping logP values during learning. + This is to prevent logP from becoming very large in magnitude and causing stability issues. + If it is None, it means no lower bound. + is_discrete_action (bool, default=True): Indicator of continuous or discrete action policy. + """ + get_v_critic_net_func: Callable[[], VNet] = None + grad_iters: int = 1 + critic_loss_cls: Callable = None + lam: float = 0.9 + min_logp: Optional[float] = None + is_discrete_action: bool = True + + +class ACBasedOps(AbsTrainOps): + """Base class of Actor-Critic algorithm implementation. Reference: https://tinyurl.com/2ezte4cr + """ + + def __init__( + self, + name: str, + policy_creator: Callable[[], RLPolicy], + get_v_critic_net_func: Callable[[], VNet], + parallelism: int = 1, + reward_discount: float = 0.9, + critic_loss_cls: Callable = None, + clip_ratio: float = None, + lam: float = 0.9, + min_logp: float = None, + is_discrete_action: bool = True, + ) -> None: + super(ACBasedOps, self).__init__( + name=name, + policy_creator=policy_creator, + parallelism=parallelism, + ) + + assert isinstance(self._policy, DiscretePolicyGradient) or isinstance(self._policy, ContinuousRLPolicy) + + self._reward_discount = reward_discount + self._critic_loss_func = critic_loss_cls() if critic_loss_cls is not None else torch.nn.MSELoss() + self._clip_ratio = clip_ratio + self._lam = lam + self._min_logp = min_logp + self._v_critic_net = get_v_critic_net_func() + self._is_discrete_action = is_discrete_action + + self._device = None + + def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor: + """Compute the critic loss of the batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + loss (torch.Tensor): The critic loss of the batch. + """ + states = ndarray_to_tensor(batch.states, device=self._device) + returns = ndarray_to_tensor(batch.returns, device=self._device) + + self._v_critic_net.train() + state_values = self._v_critic_net.v_values(states) + return self._critic_loss_func(state_values, returns) + + @remote + def get_critic_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: + """Compute the critic network's gradients of a batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + grad (torch.Tensor): The critic gradient of the batch. + """ + return self._v_critic_net.get_gradients(self._get_critic_loss(batch)) + + def update_critic(self, batch: TransitionBatch) -> None: + """Update the critic network using a batch. + + Args: + batch (TransitionBatch): Batch. + """ + self._v_critic_net.step(self._get_critic_loss(batch)) + + def update_critic_with_grad(self, grad_dict: dict) -> None: + """Update the critic network with remotely computed gradients. + + Args: + grad_dict (dict): Gradients. + """ + self._v_critic_net.train() + self._v_critic_net.apply_gradients(grad_dict) + + def _get_actor_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, bool]: + """Compute the actor loss of the batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + loss (torch.Tensor): The actor loss of the batch. + early_stop (bool): Early stop indicator. 
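_get_critic_loss above is a plain regression of V(s) onto the discounted returns; a standalone PyTorch sketch of the same computation on random tensors:

    import torch

    torch.manual_seed(0)
    state_values = torch.rand(32)    # V(s) as produced by the critic net
    returns = torch.rand(32)         # discounted returns from preprocess_batch()
    critic_loss = torch.nn.MSELoss()(state_values, returns)
    print(critic_loss.item())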
+ """ + assert isinstance(self._policy, DiscretePolicyGradient) or isinstance(self._policy, ContinuousRLPolicy) + self._policy.train() + + states = ndarray_to_tensor(batch.states, device=self._device) + actions = ndarray_to_tensor(batch.actions, device=self._device) + advantages = ndarray_to_tensor(batch.advantages, device=self._device) + logps_old = ndarray_to_tensor(batch.old_logps, device=self._device) + if self._is_discrete_action: + actions = actions.long() + + logps = self._policy.get_states_actions_logps(states, actions) + if self._clip_ratio is not None: + ratio = torch.exp(logps - logps_old) + kl = (logps_old - logps).mean().item() + early_stop = (kl >= 0.01 * 1.5) # TODO + clipped_ratio = torch.clamp(ratio, 1 - self._clip_ratio, 1 + self._clip_ratio) + actor_loss = -(torch.min(ratio * advantages, clipped_ratio * advantages)).mean() + else: + actor_loss = -(logps * advantages).mean() # I * delta * log pi(a|s) + early_stop = False + + return actor_loss, early_stop + + @remote + def get_actor_grad(self, batch: TransitionBatch) -> Tuple[Dict[str, torch.Tensor], bool]: + """Compute the actor network's gradients of a batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + grad (torch.Tensor): The actor gradient of the batch. + early_stop (bool): Early stop indicator. + """ + loss, early_stop = self._get_actor_loss(batch) + return self._policy.get_gradients(loss), early_stop + + def update_actor(self, batch: TransitionBatch) -> bool: + """Update the actor network using a batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + early_stop (bool): Early stop indicator. + """ + loss, early_stop = self._get_actor_loss(batch) + self._policy.train_step(loss) + return early_stop + + def update_actor_with_grad(self, grad_dict_and_early_stop: Tuple[dict, bool]) -> bool: + """Update the actor network with remotely computed gradients. + + Args: + grad_dict_and_early_stop (Tuple[dict, bool]): Gradients and early stop indicator. + + Returns: + early stop indicator + """ + self._policy.train() + self._policy.apply_gradients(grad_dict_and_early_stop[0]) + return grad_dict_and_early_stop[1] + + def get_non_policy_state(self) -> dict: + return { + "critic": self._v_critic_net.get_state(), + } + + def set_non_policy_state(self, state: dict) -> None: + self._v_critic_net.set_state(state["critic"]) + + def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch: + """Preprocess the batch to get the returns & advantages. + + Args: + batch (TransitionBatch): Batch. + + Returns: + The updated batch. 
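When clip_ratio is set (the PPO case), the actor loss above is the clipped surrogate objective with an approximate-KL early stop at 0.01 * 1.5; a standalone PyTorch sketch:

    import torch

    torch.manual_seed(0)
    logps, logps_old = torch.randn(32) * 0.1, torch.randn(32) * 0.1
    advantages = torch.randn(32)
    clip_ratio = 0.2

    ratio = torch.exp(logps - logps_old)                       # pi_new(a|s) / pi_old(a|s)
    clipped_ratio = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)
    actor_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

    approx_kl = (logps_old - logps).mean().item()
    early_stop = approx_kl >= 0.01 * 1.5                       # same threshold as above
    print(actor_loss.item(), early_stop)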
+ """ + assert isinstance(batch, TransitionBatch) + + # Preprocess advantages + states = ndarray_to_tensor(batch.states, device=self._device) # s + actions = ndarray_to_tensor(batch.actions, device=self._device) # a + if self._is_discrete_action: + actions = actions.long() + + with torch.no_grad(): + self._v_critic_net.eval() + self._policy.eval() + values = self._v_critic_net.v_values(states).detach().cpu().numpy() + values = np.concatenate([values, np.zeros(1)]) + rewards = np.concatenate([batch.rewards, np.zeros(1)]) + deltas = rewards[:-1] + self._reward_discount * values[1:] - values[:-1] # r + gamma * v(s') - v(s) + batch.returns = discount_cumsum(rewards, self._reward_discount)[:-1] + batch.advantages = discount_cumsum(deltas, self._reward_discount * self._lam) + + if self._clip_ratio is not None: + batch.old_logps = self._policy.get_states_actions_logps(states, actions).detach().cpu().numpy() + + return batch + + def debug_get_v_values(self, batch: TransitionBatch) -> np.ndarray: + states = ndarray_to_tensor(batch.states, device=self._device) # s + with torch.no_grad(): + values = self._v_critic_net.v_values(states).detach().cpu().numpy() + return values + + def to_device(self, device: str = None) -> None: + self._device = get_torch_device(device) + self._policy.to_device(self._device) + self._v_critic_net.to(self._device) + + +class ACBasedTrainer(SingleAgentTrainer): + """Base class of Actor-Critic algorithm implementation. + + References: + https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch + https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f + """ + + def __init__(self, name: str, params: ACBasedParams) -> None: + super(ACBasedTrainer, self).__init__(name, params) + self._params = params + + def build(self) -> None: + self._ops = self.get_ops() + self._replay_memory = FIFOReplayMemory( + capacity=self._params.replay_memory_capacity, + state_dim=self._ops.policy_state_dim, + action_dim=self._ops.policy_action_dim, + ) + + def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: + return self._ops.preprocess_batch(transition_batch) + + def get_local_ops(self) -> AbsTrainOps: + return ACBasedOps( + name=self._policy_name, + policy_creator=self._policy_creator, + parallelism=self._params.data_parallelism, + **self._params.extract_ops_params(), + ) + + def _get_batch(self) -> TransitionBatch: + batch = self._replay_memory.sample(-1) + batch.advantages = (batch.advantages - batch.advantages.mean()) / batch.advantages.std() + return batch + + def train_step(self) -> None: + assert isinstance(self._ops, ACBasedOps) + + batch = self._get_batch() + for _ in range(self._params.grad_iters): + self._ops.update_critic(batch) + + for _ in range(self._params.grad_iters): + early_stop = self._ops.update_actor(batch) + if early_stop: + break + + async def train_step_as_task(self) -> None: + assert isinstance(self._ops, RemoteOps) + + batch = self._get_batch() + for _ in range(self._params.grad_iters): + self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch)) + + for _ in range(self._params.grad_iters): + if self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch)): # early stop + break diff --git a/maro/rl/training/algorithms/ddpg.py b/maro/rl/training/algorithms/ddpg.py new file mode 100644 index 000000000..601aed6e0 --- /dev/null +++ b/maro/rl/training/algorithms/ddpg.py @@ -0,0 +1,296 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
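preprocess_batch above derives rewards-to-go and GAE advantages via discount_cumsum; a self-contained NumPy sketch of that recursion (a stand-in for maro.rl.utils.discount_cumsum), using made-up rewards and value estimates:

    import numpy as np

    def discount_cumsum(x: np.ndarray, discount: float) -> np.ndarray:
        # out[t] = x[t] + discount * x[t+1] + discount**2 * x[t+2] + ...
        out, running = np.zeros(len(x)), 0.0
        for t in reversed(range(len(x))):
            running = x[t] + discount * running
            out[t] = running
        return out

    gamma, lam = 0.9, 0.9
    rewards = np.array([1.0, 0.0, 1.0, 0.0])    # per-step rewards for a 4-step trajectory
    values = np.array([0.5, 0.4, 0.6, 0.3])     # critic estimates V(s) for the same steps

    padded_rewards = np.concatenate([rewards, np.zeros(1)])
    padded_values = np.concatenate([values, np.zeros(1)])
    deltas = padded_rewards[:-1] + gamma * padded_values[1:] - padded_values[:-1]  # r + gamma*V(s') - V(s)
    returns = discount_cumsum(padded_rewards, gamma)[:-1]                          # rewards-to-go
    advantages = discount_cumsum(deltas, gamma * lam)                              # GAE(lambda)
    print(returns.round(3), advantages.round(3))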
+ +from dataclasses import dataclass +from typing import Callable, Dict + +import torch + +from maro.rl.model import QNet +from maro.rl.policy import ContinuousRLPolicy, RLPolicy +from maro.rl.training import AbsTrainOps, RandomReplayMemory, remote, RemoteOps, SingleAgentTrainer, TrainerParams +from maro.rl.utils import get_torch_device, ndarray_to_tensor, TransitionBatch +from maro.utils import clone + + +@dataclass +class DDPGParams(TrainerParams): + """ + get_q_critic_net_func (Callable[[], QNet]): Function to get Q critic net. + num_epochs (int, default=1): Number of training epochs per call to ``learn``. + update_target_every (int, default=5): Number of training rounds between policy target model updates. + q_value_loss_cls (str, default=None): A string indicating a loss class provided by torch.nn or a custom + loss class for the Q-value loss. If it is a string, it must be a key in ``TORCH_LOSS``. + If it is None, use MSE. + soft_update_coef (float, default=1.0): Soft update coefficient, e.g., + target_model = (soft_update_coef) * eval_model + (1-soft_update_coef) * target_model. + random_overwrite (bool, default=False): This specifies overwrite behavior when the replay memory capacity + is reached. If True, overwrite positions will be selected randomly. Otherwise, overwrites will occur + sequentially with wrap-around. + min_num_to_trigger_training (int, default=0): Minimum number required to start training. + """ + get_q_critic_net_func: Callable[[], QNet] = None + num_epochs: int = 1 + update_target_every: int = 5 + q_value_loss_cls: Callable = None + soft_update_coef: float = 1.0 + random_overwrite: bool = False + min_num_to_trigger_training: int = 0 + + def __post_init__(self) -> None: + assert self.get_q_critic_net_func is not None + + def extract_ops_params(self) -> Dict[str, object]: + return { + "get_q_critic_net_func": self.get_q_critic_net_func, + "reward_discount": self.reward_discount, + "q_value_loss_cls": self.q_value_loss_cls, + "soft_update_coef": self.soft_update_coef, + } + + +class DDPGOps(AbsTrainOps): + """DDPG algorithm implementation. Reference: https://spinningup.openai.com/en/latest/algorithms/ddpg.html + """ + + def __init__( + self, + name: str, + policy_creator: Callable[[], RLPolicy], + get_q_critic_net_func: Callable[[], QNet], + reward_discount: float, + parallelism: int = 1, + q_value_loss_cls: Callable = None, + soft_update_coef: float = 1.0, + ) -> None: + super(DDPGOps, self).__init__( + name=name, + policy_creator=policy_creator, + parallelism=parallelism, + ) + + assert isinstance(self._policy, ContinuousRLPolicy) + + self._target_policy = clone(self._policy) + self._target_policy.set_name(f"target_{self._policy.name}") + self._target_policy.eval() + self._q_critic_net = get_q_critic_net_func() + self._target_q_critic_net: QNet = clone(self._q_critic_net) + self._target_q_critic_net.eval() + + self._reward_discount = reward_discount + self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss() + self._soft_update_coef = soft_update_coef + + def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor: + """Compute the critic loss of the batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + loss (torch.Tensor): The critic loss of the batch. 
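The critic loss computed next regresses Q(s, a) onto the Bellman target y = r + gamma * (1 - d) * Q_targ(s', mu_targ(s')); a standalone numeric sketch of the target:

    import torch

    gamma = 0.9
    rewards = torch.tensor([1.0, 0.0, 0.5])
    terminals = torch.tensor([0.0, 0.0, 1.0])        # d: 1 at episode end
    next_q_values = torch.tensor([2.0, 1.5, 3.0])    # Q_targ(s', mu_targ(s')) from the target nets

    # y(r, s', d) = r + gamma * (1 - d) * Q_targ(s', mu_targ(s'))
    target_q_values = rewards + gamma * (1 - terminals) * next_q_values
    print(target_q_values)                           # tensor([2.8000, 1.3500, 0.5000])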
+ """ + assert isinstance(batch, TransitionBatch) + self._q_critic_net.train() + states = ndarray_to_tensor(batch.states, device=self._device) # s + next_states = ndarray_to_tensor(batch.next_states, device=self._device) # s' + actions = ndarray_to_tensor(batch.actions, device=self._device) # a + rewards = ndarray_to_tensor(batch.rewards, device=self._device) # r + terminals = ndarray_to_tensor(batch.terminals, device=self._device) # d + + with torch.no_grad(): + next_q_values = self._target_q_critic_net.q_values( + states=next_states, # s' + actions=self._target_policy.get_actions_tensor(next_states), # miu_targ(s') + ) # Q_targ(s', miu_targ(s')) + + # y(r, s', d) = r + gamma * (1 - d) * Q_targ(s', miu_targ(s')) + target_q_values = (rewards + self._reward_discount * (1 - terminals.long()) * next_q_values).detach() + q_values = self._q_critic_net.q_values(states=states, actions=actions) # Q(s, a) + return self._q_value_loss_func(q_values, target_q_values) # MSE(Q(s, a), y(r, s', d)) + + @remote + def get_critic_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: + """Compute the critic network's gradients of a batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + grad (torch.Tensor): The critic gradient of the batch. + """ + return self._q_critic_net.get_gradients(self._get_critic_loss(batch)) + + def update_critic_with_grad(self, grad_dict: dict) -> None: + """Update the critic network with remotely computed gradients. + + Args: + grad_dict (dict): Gradients. + """ + self._q_critic_net.train() + self._q_critic_net.apply_gradients(grad_dict) + + def update_critic(self, batch: TransitionBatch) -> None: + """Update the critic network using a batch. + + Args: + batch (TransitionBatch): Batch. + """ + self._q_critic_net.train() + self._q_critic_net.step(self._get_critic_loss(batch)) + + def _get_actor_loss(self, batch: TransitionBatch) -> torch.Tensor: + """Compute the actor loss of the batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + loss (torch.Tensor): The actor loss of the batch. + """ + assert isinstance(batch, TransitionBatch) + self._policy.train() + states = ndarray_to_tensor(batch.states, device=self._device) # s + + policy_loss = -self._q_critic_net.q_values( + states=states, # s + actions=self._policy.get_actions_tensor(states), # miu(s) + ).mean() # -Q(s, miu(s)) + + return policy_loss + + @remote + def get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: + """Compute the actor network's gradients of a batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + grad (torch.Tensor): The actor gradient of the batch. + """ + return self._policy.get_gradients(self._get_actor_loss(batch)) + + def update_actor_with_grad(self, grad_dict: dict) -> None: + """Update the actor network with remotely computed gradients. + + Args: + grad_dict (dict): Gradients. + """ + self._policy.train() + self._policy.apply_gradients(grad_dict) + + def update_actor(self, batch: TransitionBatch) -> None: + """Update the actor network using a batch. + + Args: + batch (TransitionBatch): Batch. 
+ """ + self._policy.train() + self._policy.train_step(self._get_actor_loss(batch)) + + def get_non_policy_state(self) -> dict: + return { + "target_policy": self._target_policy.get_state(), + "critic": self._q_critic_net.get_state(), + "target_critic": self._target_q_critic_net.get_state(), + } + + def set_non_policy_state(self, state: dict) -> None: + self._target_policy.set_state(state["target_policy"]) + self._q_critic_net.set_state(state["critic"]) + self._target_q_critic_net.set_state(state["target_critic"]) + + def soft_update_target(self) -> None: + """Soft update the target policy and target critic. + """ + self._target_policy.soft_update(self._policy, self._soft_update_coef) + self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef) + + def to_device(self, device: str) -> None: + self._device = get_torch_device(device=device) + self._policy.to_device(self._device) + self._target_policy.to_device(self._device) + self._q_critic_net.to(self._device) + self._target_q_critic_net.to(self._device) + + +class DDPGTrainer(SingleAgentTrainer): + """The Deep Deterministic Policy Gradient (DDPG) algorithm. + + References: + https://arxiv.org/pdf/1509.02971.pdf + https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ddpg + """ + + def __init__(self, name: str, params: DDPGParams) -> None: + super(DDPGTrainer, self).__init__(name, params) + self._params = params + self._policy_version = self._target_policy_version = 0 + self._memory_size = 0 + + def build(self) -> None: + self._ops = self.get_ops() + self._replay_memory = RandomReplayMemory( + capacity=self._params.replay_memory_capacity, + state_dim=self._ops.policy_state_dim, + action_dim=self._ops.policy_action_dim, + random_overwrite=self._params.random_overwrite, + ) + + def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: + return transition_batch + + def get_local_ops(self) -> AbsTrainOps: + return DDPGOps( + name=self._policy_name, + policy_creator=self._policy_creator, + parallelism=self._params.data_parallelism, + **self._params.extract_ops_params(), + ) + + def _get_batch(self, batch_size: int = None) -> TransitionBatch: + return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size) + + def train_step(self) -> None: + assert isinstance(self._ops, DDPGOps) + + if self._replay_memory.n_sample < self._params.min_num_to_trigger_training: + print( + f"Skip this training step due to lack of experiences " + f"(current = {self._replay_memory.n_sample}, minimum = {self._params.min_num_to_trigger_training})" + ) + return + + for _ in range(self._params.num_epochs): + batch = self._get_batch() + self._ops.update_critic(batch) + self._ops.update_actor(batch) + + self._try_soft_update_target() + + async def train_step_as_task(self) -> None: + assert isinstance(self._ops, RemoteOps) + + if self._replay_memory.n_sample < self._params.min_num_to_trigger_training: + print( + f"Skip this training step due to lack of experiences " + f"(current = {self._replay_memory.n_sample}, minimum = {self._params.min_num_to_trigger_training})" + ) + return + + for _ in range(self._params.num_epochs): + batch = self._get_batch() + self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch)) + self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch)) + + self._try_soft_update_target() + + def _try_soft_update_target(self) -> None: + """Soft update the target policy and target critic. 
+ """ + self._policy_version += 1 + if self._policy_version - self._target_policy_version == self._params.update_target_every: + self._ops.soft_update_target() + self._target_policy_version = self._policy_version diff --git a/maro/rl/training/algorithms/dqn.py b/maro/rl/training/algorithms/dqn.py new file mode 100644 index 000000000..36a79dee3 --- /dev/null +++ b/maro/rl/training/algorithms/dqn.py @@ -0,0 +1,204 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from dataclasses import dataclass +from typing import Callable, Dict + +import torch + +from maro.rl.policy import RLPolicy, ValueBasedPolicy +from maro.rl.training import AbsTrainOps, RandomReplayMemory, remote, RemoteOps, SingleAgentTrainer, TrainerParams +from maro.rl.utils import get_torch_device, ndarray_to_tensor, TransitionBatch +from maro.utils import clone + + +@dataclass +class DQNParams(TrainerParams): + """ + num_epochs (int, default=1): Number of training epochs. + update_target_every (int, default=5): Number of gradient steps between target model updates. + soft_update_coef (float, default=0.1): Soft update coefficient, e.g., + target_model = (soft_update_coef) * eval_model + (1-soft_update_coef) * target_model. + double (bool, default=False): If True, the next Q values will be computed according to the double DQN algorithm, + i.e., q_next = Q_target(s, argmax(Q_eval(s, a))). Otherwise, q_next = max(Q_target(s, a)). + See https://arxiv.org/pdf/1509.06461.pdf for details. + random_overwrite (bool, default=False): This specifies overwrite behavior when the replay memory capacity + is reached. If True, overwrite positions will be selected randomly. Otherwise, overwrites will occur + sequentially with wrap-around. + """ + num_epochs: int = 1 + update_target_every: int = 5 + soft_update_coef: float = 0.1 + double: bool = False + random_overwrite: bool = False + + def extract_ops_params(self) -> Dict[str, object]: + return { + "reward_discount": self.reward_discount, + "soft_update_coef": self.soft_update_coef, + "double": self.double, + } + + +class DQNOps(AbsTrainOps): + def __init__( + self, + name: str, + policy_creator: Callable[[], RLPolicy], + parallelism: int = 1, + reward_discount: float = 0.9, + soft_update_coef: float = 0.1, + double: bool = False, + ) -> None: + super(DQNOps, self).__init__( + name=name, + policy_creator=policy_creator, + parallelism=parallelism, + ) + + assert isinstance(self._policy, ValueBasedPolicy) + + self._reward_discount = reward_discount + self._soft_update_coef = soft_update_coef + self._double = double + self._loss_func = torch.nn.MSELoss() + + self._target_policy: ValueBasedPolicy = clone(self._policy) + self._target_policy.set_name(f"target_{self._policy.name}") + self._target_policy.eval() + + def _get_batch_loss(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.Tensor]]: + """Compute the loss of the batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + loss (torch.Tensor): The loss of the batch. 
+ """ + assert isinstance(batch, TransitionBatch) + self._policy.train() + states = ndarray_to_tensor(batch.states, device=self._device) + next_states = ndarray_to_tensor(batch.next_states, device=self._device) + actions = ndarray_to_tensor(batch.actions, device=self._device) + rewards = ndarray_to_tensor(batch.rewards, device=self._device) + terminals = ndarray_to_tensor(batch.terminals, device=self._device).float() + + with torch.no_grad(): + if self._double: + self._policy.exploit() + actions_by_eval_policy = self._policy.get_actions_tensor(next_states) + next_q_values = self._target_policy.q_values_tensor(next_states, actions_by_eval_policy) + else: + self._target_policy.exploit() + actions = self._target_policy.get_actions_tensor(next_states) + next_q_values = self._target_policy.q_values_tensor(next_states, actions) + + target_q_values = (rewards + self._reward_discount * (1 - terminals) * next_q_values).detach() + q_values = self._policy.q_values_tensor(states, actions) + return self._loss_func(q_values, target_q_values) + + @remote + def get_batch_grad(self, batch: TransitionBatch) -> Dict[str, Dict[str, torch.Tensor]]: + """Compute the network's gradients of a batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + grad (torch.Tensor): The gradient of the batch. + """ + return self._policy.get_gradients(self._get_batch_loss(batch)) + + def update_with_grad(self, grad_dict: dict) -> None: + """Update the network with remotely computed gradients. + + Args: + grad_dict (dict): Gradients. + """ + self._policy.train() + self._policy.apply_gradients(grad_dict) + + def update(self, batch: TransitionBatch) -> None: + """Update the network using a batch. + + Args: + batch (TransitionBatch): Batch. + """ + self._policy.train() + self._policy.train_step(self._get_batch_loss(batch)) + + def get_non_policy_state(self) -> dict: + return { + "target_q_net": self._target_policy.get_state(), + } + + def set_non_policy_state(self, state: dict) -> None: + self._target_policy.set_state(state["target_q_net"]) + + def soft_update_target(self) -> None: + """Soft update the target policy. + """ + self._target_policy.soft_update(self._policy, self._soft_update_coef) + + def to_device(self, device: str) -> None: + self._device = get_torch_device(device) + self._policy.to_device(self._device) + self._target_policy.to_device(self._device) + + +class DQNTrainer(SingleAgentTrainer): + """The Deep-Q-Networks algorithm. + + See https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf for details. 
+ """ + + def __init__(self, name: str, params: DQNParams) -> None: + super(DQNTrainer, self).__init__(name, params) + self._params = params + self._q_net_version = self._target_q_net_version = 0 + + def build(self) -> None: + self._ops = self.get_ops() + self._replay_memory = RandomReplayMemory( + capacity=self._params.replay_memory_capacity, + state_dim=self._ops.policy_state_dim, + action_dim=self._ops.policy_action_dim, + random_overwrite=self._params.random_overwrite, + ) + + def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: + return transition_batch + + def get_local_ops(self) -> AbsTrainOps: + return DQNOps( + name=self._policy_name, + policy_creator=self._policy_creator, + parallelism=self._params.data_parallelism, + **self._params.extract_ops_params(), + ) + + def _get_batch(self, batch_size: int = None) -> TransitionBatch: + return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size) + + def train_step(self) -> None: + assert isinstance(self._ops, DQNOps) + for _ in range(self._params.num_epochs): + self._ops.update(self._get_batch()) + + self._try_soft_update_target() + + async def train_step_as_task(self) -> None: + assert isinstance(self._ops, RemoteOps) + for _ in range(self._params.num_epochs): + batch = self._get_batch() + self._ops.update_with_grad(await self._ops.get_batch_grad(batch)) + + self._try_soft_update_target() + + def _try_soft_update_target(self) -> None: + """Soft update the target policy and target critic. + """ + self._q_net_version += 1 + if self._q_net_version - self._target_q_net_version == self._params.update_target_every: + self._ops.soft_update_target() + self._target_q_net_version = self._q_net_version diff --git a/maro/rl/training/algorithms/maddpg.py b/maro/rl/training/algorithms/maddpg.py new file mode 100644 index 000000000..2a6a2b688 --- /dev/null +++ b/maro/rl/training/algorithms/maddpg.py @@ -0,0 +1,496 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import asyncio +import os +from dataclasses import dataclass +from typing import Callable, Dict, List, Tuple + +import numpy as np +import torch + +from maro.rl.model import MultiQNet +from maro.rl.policy import DiscretePolicyGradient, RLPolicy +from maro.rl.rollout import ExpElement +from maro.rl.training import AbsTrainOps, MultiAgentTrainer, RandomMultiReplayMemory, remote, RemoteOps, TrainerParams +from maro.rl.utils import get_torch_device, MultiTransitionBatch, ndarray_to_tensor +from maro.rl.utils.objects import FILE_SUFFIX +from maro.utils import clone + + +@dataclass +class DiscreteMADDPGParams(TrainerParams): + """ + get_q_critic_net_func (Callable[[], MultiQNet]): Function to get multi Q critic net. + num_epochs (int, default=10): Number of training epochs. + update_target_every (int, default=5): Number of gradient steps between target model updates. + soft_update_coef (float, default=0.5): Soft update coefficient, e.g., + target_model = (soft_update_coef) * eval_model + (1-soft_update_coef) * target_model. + q_value_loss_cls (Callable, default=None): Critic loss function. If it is None, use MSE. + shared_critic (bool, default=False): Whether different policies use shared critic or individual policies. 
+ """ + get_q_critic_net_func: Callable[[], MultiQNet] = None + num_epoch: int = 10 + update_target_every: int = 5 + soft_update_coef: float = 0.5 + q_value_loss_cls: Callable = None + shared_critic: bool = False + + def __post_init__(self) -> None: + assert self.get_q_critic_net_func is not None + + def extract_ops_params(self) -> Dict[str, object]: + return { + "get_q_critic_net_func": self.get_q_critic_net_func, + "shared_critic": self.shared_critic, + "reward_discount": self.reward_discount, + "soft_update_coef": self.soft_update_coef, + "update_target_every": self.update_target_every, + "q_value_loss_func": self.q_value_loss_cls() if self.q_value_loss_cls is not None else torch.nn.MSELoss(), + } + + +class DiscreteMADDPGOps(AbsTrainOps): + def __init__( + self, + name: str, + policy_creator: Callable[[], RLPolicy], + get_q_critic_net_func: Callable[[], MultiQNet], + policy_idx: int, + parallelism: int = 1, + shared_critic: bool = False, + reward_discount: float = 0.9, + soft_update_coef: float = 0.5, + update_target_every: int = 5, + q_value_loss_func: Callable = None, + ) -> None: + super(DiscreteMADDPGOps, self).__init__( + name=name, + policy_creator=policy_creator, + parallelism=parallelism + ) + + self._policy_idx = policy_idx + self._shared_critic = shared_critic + + # Actor + if self._policy_creator: + assert isinstance(self._policy, DiscretePolicyGradient) + self._target_policy: DiscretePolicyGradient = clone(self._policy) + self._target_policy.set_name(f"target_{self._policy.name}") + self._target_policy.eval() + + # Critic + self._q_critic_net: MultiQNet = get_q_critic_net_func() + self._target_q_critic_net: MultiQNet = clone(self._q_critic_net) + self._target_q_critic_net.eval() + + self._reward_discount = reward_discount + self._q_value_loss_func = q_value_loss_func + self._update_target_every = update_target_every + self._soft_update_coef = soft_update_coef + + self._device = None + + def get_target_action(self, batch: MultiTransitionBatch) -> torch.Tensor: + """Get the target policies' actions according to the batch. + + Args: + batch (MultiTransitionBatch): Batch. + + Returns: + actions (torch.Tensor): Target policies' actions. + """ + agent_state = ndarray_to_tensor(batch.agent_states[self._policy_idx], device=self._device) + return self._target_policy.get_actions_tensor(agent_state) + + def get_latest_action(self, batch: MultiTransitionBatch) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the latest actions and corresponding log-probabilities according to the batch. + + Args: + batch (MultiTransitionBatch): Batch. + + Returns: + actions (torch.Tensor): Target policies' actions. + logps (torch.Tensor): Log-probabilities. + """ + assert isinstance(self._policy, DiscretePolicyGradient) + + agent_state = ndarray_to_tensor(batch.agent_states[self._policy_idx], device=self._device) + self._policy.train() + action = self._policy.get_actions_tensor(agent_state) + logps = self._policy.get_states_actions_logps(agent_state, action) + return action, logps + + def _get_critic_loss(self, batch: MultiTransitionBatch, next_actions: List[torch.Tensor]) -> torch.Tensor: + """Compute the critic loss of the batch. + + Args: + batch (MultiTransitionBatch): Batch. + next_actions (List[torch.Tensor]): List of next actions of all policies. + + Returns: + loss (torch.Tensor): The critic loss of the batch. 
+ """ + assert not self._shared_critic + assert isinstance(next_actions, list) and all(isinstance(action, torch.Tensor) for action in next_actions) + + states = ndarray_to_tensor(batch.states, device=self._device) # x + actions = [ndarray_to_tensor(action, device=self._device) for action in batch.actions] # a + next_states = ndarray_to_tensor(batch.next_states, device=self._device) # x' + rewards = ndarray_to_tensor(np.vstack([reward for reward in batch.rewards]), device=self._device) # r + terminals = ndarray_to_tensor(batch.terminals, device=self._device) # d + + self._q_critic_net.train() + with torch.no_grad(): + next_q_values = self._target_q_critic_net.q_values( + states=next_states, # x' + actions=next_actions, + ) # a' + target_q_values = ( + rewards[self._policy_idx] + self._reward_discount * (1 - terminals.float()) * next_q_values + ) + q_values = self._q_critic_net.q_values( + states=states, # x + actions=actions, # a + ) # Q(x, a) + return self._q_value_loss_func(q_values, target_q_values.detach()) + + @remote + def get_critic_grad( + self, + batch: MultiTransitionBatch, + next_actions: List[torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Compute the critic network's gradients of a batch. + + Args: + batch (MultiTransitionBatch): Batch. + next_actions (List[torch.Tensor]): List of next actions of all policies. + + Returns: + grad (torch.Tensor): The critic gradient of the batch. + """ + return self._q_critic_net.get_gradients(self._get_critic_loss(batch, next_actions)) + + def update_critic(self, batch: MultiTransitionBatch, next_actions: List[torch.Tensor]) -> None: + """Update the critic network using a batch. + + Args: + batch (MultiTransitionBatch): Batch. + next_actions (List[torch.Tensor]): List of next actions of all policies. + """ + self._q_critic_net.train() + self._q_critic_net.step(self._get_critic_loss(batch, next_actions)) + + def update_critic_with_grad(self, grad_dict: dict) -> None: + """Update the critic network with remotely computed gradients. + + Args: + grad_dict (dict): Gradients. + """ + self._q_critic_net.train() + self._q_critic_net.apply_gradients(grad_dict) + + def _get_actor_loss(self, batch: MultiTransitionBatch) -> torch.Tensor: + """Compute the actor loss of the batch. + + Args: + batch (MultiTransitionBatch): Batch. + + Returns: + loss (torch.Tensor): The actor loss of the batch. + """ + latest_action, latest_action_logp = self.get_latest_action(batch) + states = ndarray_to_tensor(batch.states, device=self._device) # x + actions = [ndarray_to_tensor(action, device=self._device) for action in batch.actions] # a + actions[self._policy_idx] = latest_action + self._policy.train() + self._q_critic_net.freeze() + actor_loss = -(self._q_critic_net.q_values( + states=states, # x + actions=actions, # [a^j_1, ..., a_i, ..., a^j_N] + ) * latest_action_logp).mean() # Q(x, a^j_1, ..., a_i, ..., a^j_N) + self._q_critic_net.unfreeze() + return actor_loss + + @remote + def get_actor_grad(self, batch: MultiTransitionBatch) -> Dict[str, torch.Tensor]: + """Compute the actor network's gradients of a batch. + + Args: + batch (MultiTransitionBatch): Batch. + + Returns: + grad (torch.Tensor): The actor gradient of the batch. + """ + return self._policy.get_gradients(self._get_actor_loss(batch)) + + def update_actor(self, batch: MultiTransitionBatch) -> None: + """Update the actor network using a batch. + + Args: + batch (MultiTransitionBatch): Batch. 
+ """ + self._policy.train() + self._policy.train_step(self._get_actor_loss(batch)) + + def update_actor_with_grad(self, grad_dict: dict) -> None: + """Update the critic network with remotely computed gradients. + + Args: + grad_dict (dict): Gradients. + """ + self._policy.train() + self._policy.apply_gradients(grad_dict) + + def soft_update_target(self) -> None: + """Soft update the target policies and target critics. + """ + if self._policy_creator: + self._target_policy.soft_update(self._policy, self._soft_update_coef) + if not self._shared_critic: + self._target_q_critic_net.soft_update(self._q_critic_net, self._soft_update_coef) + + def get_critic_state(self) -> dict: + return { + "critic": self._q_critic_net.get_state(), + "target_critic": self._target_q_critic_net.get_state(), + } + + def set_critic_state(self, ops_state_dict: dict) -> None: + self._q_critic_net.set_state(ops_state_dict["critic"]) + self._target_q_critic_net.set_state(ops_state_dict["target_critic"]) + + def get_actor_state(self) -> dict: + if self._policy_creator: + return {"policy": self._policy.get_state(), "target_policy": self._target_policy.get_state()} + else: + return {} + + def set_actor_state(self, ops_state_dict: dict) -> None: + if self._policy_creator: + self._policy.set_state(ops_state_dict["policy"]) + self._target_policy.set_state(ops_state_dict["target_policy"]) + + def get_non_policy_state(self) -> dict: + return self.get_critic_state() + + def set_non_policy_state(self, state: dict) -> None: + self.set_critic_state(state) + + def to_device(self, device: str) -> None: + self._device = get_torch_device(device) + if self._policy_creator: + self._policy.to_device(self._device) + self._target_policy.to_device(self._device) + + self._q_critic_net.to(self._device) + self._target_q_critic_net.to(self._device) + + +class DiscreteMADDPGTrainer(MultiAgentTrainer): + """Multi-agent deep deterministic policy gradient (MADDPG) algorithm adapted for discrete action space. + + See https://arxiv.org/abs/1706.02275 for details. 
+ """ + + def __init__(self, name: str, params: DiscreteMADDPGParams) -> None: + super(DiscreteMADDPGTrainer, self).__init__(name, params) + self._params = params + self._ops_params = self._params.extract_ops_params() + self._state_dim = params.get_q_critic_net_func().state_dim + self._policy_version = self._target_policy_version = 0 + self._shared_critic_ops_name = f"{self._name}.shared_critic" + + self._actor_ops_list = [] + self._critic_ops = None + self._replay_memory = None + self._policy2agent = {} + + def build(self) -> None: + for policy_name in self._policy_creator: + self._ops_dict[policy_name] = self.get_ops(policy_name) + + self._actor_ops_list = list(self._ops_dict.values()) + + if self._params.shared_critic: + self._ops_dict[self._shared_critic_ops_name] = self.get_ops(self._shared_critic_ops_name) + self._critic_ops = self._ops_dict[self._shared_critic_ops_name] + + self._replay_memory = RandomMultiReplayMemory( + capacity=self._params.replay_memory_capacity, + state_dim=self._state_dim, + action_dims=[ops.policy_action_dim for ops in self._actor_ops_list], + agent_states_dims=[ops.policy_state_dim for ops in self._actor_ops_list], + ) + + assert len(self._agent2policy.keys()) == len(self._agent2policy.values()) # agent <=> policy + self._policy2agent = {policy_name: agent_name for agent_name, policy_name in self._agent2policy.items()} + + def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None: + terminal_flags: List[bool] = [] + for exp_element in exp_elements: + assert exp_element.num_agents == len(self._agent2policy.keys()) + + if min(exp_element.terminal_dict.values()) != max(exp_element.terminal_dict.values()): + raise ValueError("The 'terminal` flag of all agents at every tick must be identical.") + terminal_flags.append(min(exp_element.terminal_dict.values())) + + actions: List[np.ndarray] = [] + rewards: List[np.ndarray] = [] + agent_states: List[np.ndarray] = [] + next_agent_states: List[np.ndarray] = [] + for policy_name in self._policy_names: + agent_name = self._policy2agent[policy_name] + actions.append(np.vstack([exp_element.action_dict[agent_name] for exp_element in exp_elements])) + rewards.append(np.array([exp_element.reward_dict[agent_name] for exp_element in exp_elements])) + agent_states.append(np.vstack([exp_element.agent_state_dict[agent_name] for exp_element in exp_elements])) + next_agent_states.append(np.vstack( + [ + exp_element.next_agent_state_dict.get(agent_name, exp_element.agent_state_dict[agent_name]) + for exp_element in exp_elements + ] + )) + + transition_batch = MultiTransitionBatch( + states=np.vstack([exp_element.state for exp_element in exp_elements]), + actions=actions, + rewards=rewards, + next_states=np.vstack( + [ + exp_element.next_state if exp_element.next_state is not None else exp_element.state + for exp_element in exp_elements + ] + ), + agent_states=agent_states, + next_agent_states=next_agent_states, + terminals=np.array(terminal_flags), + ) + self._replay_memory.put(transition_batch) + + def get_local_ops(self, name: str) -> AbsTrainOps: + if name == self._shared_critic_ops_name: + ops_params = dict(self._ops_params) + ops_params.update({ + "policy_idx": -1, + "shared_critic": False, + }) + return DiscreteMADDPGOps(name=name, **ops_params) + else: + ops_params = dict(self._ops_params) + ops_params.update({ + "policy_creator": self._policy_creator[name], + "policy_idx": self._policy_names.index(name), + }) + return DiscreteMADDPGOps(name=name, **ops_params) + + def _get_batch(self, batch_size: int = 
None) -> MultiTransitionBatch: + return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size) + + def train_step(self) -> None: + assert not self._params.shared_critic or isinstance(self._critic_ops, DiscreteMADDPGOps) + assert all(isinstance(ops, DiscreteMADDPGOps) for ops in self._actor_ops_list) + for _ in range(self._params.num_epoch): + batch = self._get_batch() + # Collect next actions + next_actions = [ops.get_target_action(batch) for ops in self._actor_ops_list] + + # Update critic + if self._params.shared_critic: + self._critic_ops.update_critic(batch, next_actions) + critic_state_dict = self._critic_ops.get_critic_state() + # Sync latest critic to ops + for ops in self._actor_ops_list: + ops.set_critic_state(critic_state_dict) + else: + for ops in self._actor_ops_list: + ops.update_critic(batch, next_actions) + + # Update actors + for ops in self._actor_ops_list: + ops.update_actor(batch) + + # Update version + self._try_soft_update_target() + + async def train_step_as_task(self) -> None: + assert not self._params.shared_critic or isinstance(self._critic_ops, RemoteOps) + assert all(isinstance(ops, RemoteOps) for ops in self._actor_ops_list) + for _ in range(self._params.num_epoch): + batch = self._get_batch() + # Collect next actions + next_actions = [ops.get_target_action(batch) for ops in self._actor_ops_list] + + # Update critic + if self._params.shared_critic: + critic_grad = await asyncio.gather(*[self._critic_ops.get_critic_grad(batch, next_actions)]) + assert isinstance(critic_grad, list) and isinstance(critic_grad[0], dict) + self._critic_ops.update_critic_with_grad(critic_grad[0]) + critic_state_dict = self._critic_ops.get_critic_state() + # Sync latest critic to ops + for ops in self._actor_ops_list: + ops.set_critic_state(critic_state_dict) + else: + critic_grad_list = await asyncio.gather( + *[ops.get_critic_grad(batch, next_actions) for ops in self._actor_ops_list] + ) + for ops, critic_grad in zip(self._actor_ops_list, critic_grad_list): + ops.update_critic_with_grad(critic_grad) + + # Update actors + actor_grad_list = await asyncio.gather(*[ops.get_actor_grad(batch) for ops in self._actor_ops_list]) + for ops, actor_grad in zip(self._actor_ops_list, actor_grad_list): + ops.update_actor_with_grad(actor_grad) + + # Update version + self._try_soft_update_target() + + def _try_soft_update_target(self) -> None: + """Soft update the target policies and target critics. 
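
The concurrency in `train_step_as_task` relies on `asyncio.gather` returning results in the same order as its inputs, which is what makes `zip(self._actor_ops_list, grad_list)` safe. A tiny standalone illustration; the coroutine below is a fake stand-in for a remote `get_*_grad` call.

```python
import asyncio
from typing import Dict, List

async def fake_remote_grad(name: str) -> Dict[str, float]:
    # Stand-in for an awaited remote gradient computation.
    await asyncio.sleep(0.01)
    return {name: 1.0}

async def gather_in_order(names: List[str]) -> List[Dict[str, float]]:
    # asyncio.gather preserves input order, so results[i] belongs to names[i].
    return await asyncio.gather(*[fake_remote_grad(n) for n in names])

grads = asyncio.run(gather_in_order(["actor_0", "actor_1", "actor_2"]))
# -> [{'actor_0': 1.0}, {'actor_1': 1.0}, {'actor_2': 1.0}]
```
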
+ """ + self._policy_version += 1 + if self._policy_version - self._target_policy_version == self._params.update_target_every: + for ops in self._actor_ops_list: + ops.soft_update_target() + if self._params.shared_critic: + self._critic_ops.soft_update_target() + self._target_policy_version = self._policy_version + + def get_policy_state(self) -> Dict[str, object]: + self._assert_ops_exists() + ret_policy_state = {} + for ops in self._actor_ops_list: + policy_name, state = ops.get_policy_state() + ret_policy_state[policy_name] = state + return ret_policy_state + + def load(self, path: str) -> None: + self._assert_ops_exists() + + policy_state_dict = torch.load(os.path.join(path, f"{self.name}_policy.{FILE_SUFFIX}")) + non_policy_state_dict = torch.load(os.path.join(path, f"{self.name}_non_policy.{FILE_SUFFIX}")) + for ops_name in policy_state_dict: + self._ops_dict[ops_name].set_state({**policy_state_dict[ops_name], **non_policy_state_dict[ops_name]}) + + def save(self, path: str) -> None: + self._assert_ops_exists() + + trainer_state = {ops.name: ops.get_state() for ops in self._actor_ops_list} + if self._params.shared_critic: + trainer_state[self._critic_ops.name] = self._critic_ops.get_state() + + policy_state_dict = {ops_name: state["policy"] for ops_name, state in trainer_state.items()} + non_policy_state_dict = {ops_name: state["non_policy"] for ops_name, state in trainer_state.items()} + + torch.save(policy_state_dict, os.path.join(path, f"{self.name}_policy.{FILE_SUFFIX}")) + torch.save(non_policy_state_dict, os.path.join(path, f"{self.name}_non_policy.{FILE_SUFFIX}")) + + def _assert_ops_exists(self) -> None: + if not self._actor_ops_list: + raise ValueError("Call 'DiscreteMADDPG.build' to create actor ops first.") + if self._params.shared_critic and not self._critic_ops: + raise ValueError("Call 'DiscreteMADDPG.build' to create the critic ops first.") + + async def exit(self) -> None: + pass diff --git a/maro/rl/training/algorithms/ppo.py b/maro/rl/training/algorithms/ppo.py new file mode 100644 index 000000000..9ec892ecc --- /dev/null +++ b/maro/rl/training/algorithms/ppo.py @@ -0,0 +1,196 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass +from typing import Callable, Dict, Tuple + +import numpy as np +import torch +from torch.distributions import Categorical + +from maro.rl.model import VNet +from maro.rl.policy import DiscretePolicyGradient, RLPolicy +from maro.rl.training.algorithms.base import ACBasedOps, ACBasedParams, ACBasedTrainer +from maro.rl.utils import discount_cumsum, ndarray_to_tensor, TransitionBatch + + +@dataclass +class PPOParams(ACBasedParams): + """Mostly inherited from `ACBasedParams`. Please refer to the doc string of `ACBasedParams` + for more detailed information. + + clip_ratio (float, default=None): Clip ratio in the PPO algorithm (https://arxiv.org/pdf/1707.06347.pdf). + If it is None, the actor loss is calculated using the usual policy gradient theorem. 
+ """ + clip_ratio: float = None + + def extract_ops_params(self) -> Dict[str, object]: + return { + "get_v_critic_net_func": self.get_v_critic_net_func, + "reward_discount": self.reward_discount, + "critic_loss_cls": self.critic_loss_cls, + "clip_ratio": self.clip_ratio, + "lam": self.lam, + "min_logp": self.min_logp, + "is_discrete_action": self.is_discrete_action, + } + + def __post_init__(self) -> None: + assert self.get_v_critic_net_func is not None + assert self.clip_ratio is not None + + +class DiscretePPOWithEntropyOps(ACBasedOps): + def __init__( + self, + name: str, + policy_creator: Callable[[], RLPolicy], + get_v_critic_net_func: Callable[[], VNet], + parallelism: int = 1, + reward_discount: float = 0.9, + critic_loss_cls: Callable = None, + clip_ratio: float = None, + lam: float = 0.9, + min_logp: float = None, + is_discrete_action: bool = True, + ) -> None: + super(DiscretePPOWithEntropyOps, self).__init__( + name=name, + policy_creator=policy_creator, + get_v_critic_net_func=get_v_critic_net_func, + parallelism=parallelism, + reward_discount=reward_discount, + critic_loss_cls=critic_loss_cls, + clip_ratio=clip_ratio, + lam=lam, + min_logp=min_logp, + is_discrete_action=is_discrete_action, + ) + assert is_discrete_action + assert isinstance(self._policy, DiscretePolicyGradient) + self._policy_old = self._policy_creator() + self.update_policy_old() + + def update_policy_old(self) -> None: + self._policy_old.set_state(self._policy.get_state()) + + def _get_critic_loss(self, batch: TransitionBatch) -> torch.Tensor: + """Compute the critic loss of the batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + loss (torch.Tensor): The critic loss of the batch. + """ + self._v_critic_net.train() + states = ndarray_to_tensor(batch.states, self._device) + state_values = self._v_critic_net.v_values(states) + + values = state_values.cpu().detach().numpy() + values = np.concatenate([values[1:], values[-1:]]) + returns = batch.rewards + np.where(batch.terminals, 0.0, 1.0) * self._reward_discount * values + # special care for tail state + returns[-1] = state_values[-1] + returns = ndarray_to_tensor(returns, self._device) + + return self._critic_loss_func(state_values.float(), returns.float()) + + def _get_actor_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, bool]: + """Compute the actor loss of the batch. + + Args: + batch (TransitionBatch): Batch. + + Returns: + loss (torch.Tensor): The actor loss of the batch. + early_stop (bool): Early stop indicator. 
+ """ + assert isinstance(self._policy, DiscretePolicyGradient) + self._policy.train() + + states = ndarray_to_tensor(batch.states, device=self._device) + actions = ndarray_to_tensor(batch.actions, device=self._device) + advantages = ndarray_to_tensor(batch.advantages, device=self._device) + logps_old = ndarray_to_tensor(batch.old_logps, device=self._device) + if self._is_discrete_action: + actions = actions.long() + + action_probs = self._policy.get_action_probs(states) + dist_entropy = Categorical(action_probs).entropy() + logps = torch.log(action_probs.gather(1, actions).squeeze()) + logps = torch.clamp(logps, min=self._min_logp, max=.0) + if self._clip_ratio is not None: + ratio = torch.exp(logps - logps_old) + clipped_ratio = torch.clamp(ratio, 1 - self._clip_ratio, 1 + self._clip_ratio) + actor_loss = -(torch.min(ratio * advantages, clipped_ratio * advantages)).float() + kl = (logps_old - logps).mean().item() + early_stop = (kl >= 0.01 * 1.5) # TODO + else: + actor_loss = -(logps * advantages).float() # I * delta * log pi(a|s) + early_stop = False + actor_loss = (actor_loss - 0.2 * dist_entropy).mean() + + return actor_loss, early_stop + + def preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch: + """Preprocess the batch to get the returns & advantages. + + Args: + batch (TransitionBatch): Batch. + + Returns: + The updated batch. + """ + assert isinstance(batch, TransitionBatch) + # Preprocess returns + batch.returns = discount_cumsum(batch.rewards, self._reward_discount) + + # Preprocess advantages + states = ndarray_to_tensor(batch.states, self._device) + state_values = self._v_critic_net.v_values(states).cpu().detach().numpy() + values = np.concatenate([state_values[1:], np.zeros(1).astype(np.float32)]) + deltas = (batch.rewards + self._reward_discount * values - state_values) + # special care for tail state + deltas[-1] = 0.0 + batch.advantages = discount_cumsum(deltas, self._reward_discount * self._lam) + + if self._clip_ratio is not None: + self._policy_old.eval() + actions = ndarray_to_tensor(batch.actions, device=self._device).long() + batch.old_logps = self._policy_old.get_states_actions_logps(states, actions).detach().cpu().numpy() + self._policy_old.train() + + return batch + + +class PPOTrainer(ACBasedTrainer): + """PPO algorithm. + + References: + https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/ppo. 
+ """ + + def __init__(self, name: str, params: PPOParams) -> None: + super(PPOTrainer, self).__init__(name, params) + + +class DiscretePPOWithEntropyTrainer(ACBasedTrainer): + def __init__(self, name: str, params: PPOParams) -> None: + super(DiscretePPOWithEntropyTrainer, self).__init__(name, params) + + def get_local_ops(self) -> DiscretePPOWithEntropyOps: + return DiscretePPOWithEntropyOps( + name=self._policy_name, + policy_creator=self._policy_creator, + parallelism=self._params.data_parallelism, + **self._params.extract_ops_params(), + ) + + def train_step(self) -> None: + assert isinstance(self._ops, DiscretePPOWithEntropyOps) + batch = self._get_batch() + for _ in range(self._params.grad_iters): + self._ops.update_critic(batch) + self._ops.update_actor(batch) + self._ops.update_policy_old() diff --git a/maro/rl/training/algorithms/sac.py b/maro/rl/training/algorithms/sac.py new file mode 100644 index 000000000..e03743001 --- /dev/null +++ b/maro/rl/training/algorithms/sac.py @@ -0,0 +1,233 @@ +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple + +import torch + +from maro.rl.model import QNet +from maro.rl.policy import ContinuousRLPolicy, RLPolicy +from maro.rl.training import AbsTrainOps, RandomReplayMemory, remote, RemoteOps, SingleAgentTrainer, TrainerParams +from maro.rl.utils import get_torch_device, ndarray_to_tensor, TransitionBatch +from maro.utils import clone + + +@dataclass +class SoftActorCriticParams(TrainerParams): + get_q_critic_net_func: Callable[[], QNet] = None + update_target_every: int = 5 + random_overwrite: bool = False + entropy_coef: float = 0.1 + num_epochs: int = 1 + n_start_train: int = 0 + q_value_loss_cls: Callable = None + soft_update_coef: float = 1.0 + + def __post_init__(self) -> None: + assert self.get_q_critic_net_func is not None + + def extract_ops_params(self) -> Dict[str, object]: + return { + "get_q_critic_net_func": self.get_q_critic_net_func, + "entropy_coef": self.entropy_coef, + "reward_discount": self.reward_discount, + "q_value_loss_cls": self.q_value_loss_cls, + "soft_update_coef": self.soft_update_coef, + } + + +class SoftActorCriticOps(AbsTrainOps): + def __init__( + self, + name: str, + policy_creator: Callable[[], RLPolicy], + get_q_critic_net_func: Callable[[], QNet], + parallelism: int = 1, + *, + entropy_coef: float, + reward_discount: float, + q_value_loss_cls: Callable = None, + soft_update_coef: float = 1.0, + ) -> None: + super(SoftActorCriticOps, self).__init__( + name=name, + policy_creator=policy_creator, + parallelism=parallelism, + ) + + assert isinstance(self._policy, ContinuousRLPolicy) + + self._q_net1 = get_q_critic_net_func() + self._q_net2 = get_q_critic_net_func() + self._target_q_net1: QNet = clone(self._q_net1) + self._target_q_net1.eval() + self._target_q_net2: QNet = clone(self._q_net2) + self._target_q_net2.eval() + + self._entropy_coef = entropy_coef + self._soft_update_coef = soft_update_coef + self._reward_discount = reward_discount + self._q_value_loss_func = q_value_loss_cls() if q_value_loss_cls is not None else torch.nn.MSELoss() + + def _get_critic_loss(self, batch: TransitionBatch) -> Tuple[torch.Tensor, torch.Tensor]: + self._q_net1.train() + states = ndarray_to_tensor(batch.states, device=self._device) # s + next_states = ndarray_to_tensor(batch.next_states, device=self._device) # s' + actions = ndarray_to_tensor(batch.actions, device=self._device) # a + rewards = ndarray_to_tensor(batch.rewards, device=self._device) # r + terminals = 
ndarray_to_tensor(batch.terminals, device=self._device) # d + + assert isinstance(self._policy, ContinuousRLPolicy) + + with torch.no_grad(): + next_actions, next_logps = self._policy.get_actions_with_logps(states) + q1 = self._target_q_net1.q_values(next_states, next_actions) + q2 = self._target_q_net2.q_values(next_states, next_actions) + q = torch.min(q1, q2) + y = rewards + self._reward_discount * (1.0 - terminals.float()) * (q - self._entropy_coef * next_logps) + + q1 = self._q_net1.q_values(states, actions) + q2 = self._q_net2.q_values(states, actions) + loss_q1 = self._q_value_loss_func(q1, y) + loss_q2 = self._q_value_loss_func(q2, y) + return loss_q1, loss_q2 + + @remote + def get_critic_grad(self, batch: TransitionBatch) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + loss_q1, loss_q2 = self._get_critic_loss(batch) + grad_q1 = self._q_net1.get_gradients(loss_q1) + grad_q2 = self._q_net2.get_gradients(loss_q2) + return grad_q1, grad_q2 + + def update_critic_with_grad(self, grad_dict1: dict, grad_dict2: dict) -> None: + self._q_net1.train() + self._q_net2.train() + self._q_net1.apply_gradients(grad_dict1) + self._q_net2.apply_gradients(grad_dict2) + + def update_critic(self, batch: TransitionBatch) -> None: + self._q_net1.train() + self._q_net2.train() + loss_q1, loss_q2 = self._get_critic_loss(batch) + self._q_net1.step(loss_q1) + self._q_net2.step(loss_q2) + + def _get_actor_loss(self, batch: TransitionBatch) -> torch.Tensor: + self._policy.train() + states = ndarray_to_tensor(batch.states, device=self._device) # s + actions, logps = self._policy.get_actions_with_logps(states) + q1 = self._q_net1.q_values(states, actions) + q2 = self._q_net2.q_values(states, actions) + q = torch.min(q1, q2) + + loss = (self._entropy_coef * logps - q).mean() + return loss + + @remote + def get_actor_grad(self, batch: TransitionBatch) -> Dict[str, torch.Tensor]: + return self._policy.get_gradients(self._get_actor_loss(batch)) + + def update_actor_with_grad(self, grad_dict: dict) -> None: + self._policy.train() + self._policy.apply_gradients(grad_dict) + + def update_actor(self, batch: TransitionBatch) -> None: + self._policy.train() + self._policy.train_step(self._get_actor_loss(batch)) + + def get_non_policy_state(self) -> dict: + return { + "q_net1": self._q_net1.get_state(), + "q_net2": self._q_net2.get_state(), + "target_q_net1": self._target_q_net1.get_state(), + "target_q_net2": self._target_q_net2.get_state(), + } + + def set_non_policy_state(self, state: dict) -> None: + self._q_net1.set_state(state["q_net1"]) + self._q_net2.set_state(state["q_net2"]) + self._target_q_net1.set_state(state["target_q_net1"]) + self._target_q_net2.set_state(state["target_q_net2"]) + + def soft_update_target(self) -> None: + self._target_q_net1.soft_update(self._q_net1, self._soft_update_coef) + self._target_q_net2.soft_update(self._q_net2, self._soft_update_coef) + + def to_device(self, device: str) -> None: + self._device = get_torch_device(device=device) + self._q_net1.to(self._device) + self._q_net2.to(self._device) + self._target_q_net1.to(self._device) + self._target_q_net2.to(self._device) + + +class SoftActorCriticTrainer(SingleAgentTrainer): + def __init__(self, name: str, params: SoftActorCriticParams) -> None: + super(SoftActorCriticTrainer, self).__init__(name, params) + self._params = params + self._qnet_version = self._target_qnet_version = 0 + + self._replay_memory: Optional[RandomReplayMemory] = None + + def build(self) -> None: + self._ops = self.get_ops() + self._replay_memory = 
RandomReplayMemory( + capacity=self._params.replay_memory_capacity, + state_dim=self._ops.policy_state_dim, + action_dim=self._ops.policy_action_dim, + random_overwrite=self._params.random_overwrite, + ) + + def train_step(self) -> None: + assert isinstance(self._ops, SoftActorCriticOps) + + if self._replay_memory.n_sample < self._params.n_start_train: + print( + f"Skip this training step due to lack of experiences " + f"(current = {self._replay_memory.n_sample}, minimum = {self._params.n_start_train})" + ) + return + + for _ in range(self._params.num_epochs): + batch = self._get_batch() + self._ops.update_critic(batch) + self._ops.update_actor(batch) + + self._try_soft_update_target() + + async def train_step_as_task(self) -> None: + assert isinstance(self._ops, RemoteOps) + + if self._replay_memory.n_sample < self._params.n_start_train: + print( + f"Skip this training step due to lack of experiences " + f"(current = {self._replay_memory.n_sample}, minimum = {self._params.n_start_train})" + ) + return + + for _ in range(self._params.num_epochs): + batch = self._get_batch() + self._ops.update_critic_with_grad(await self._ops.get_critic_grad(batch)) + self._ops.update_actor_with_grad(await self._ops.get_actor_grad(batch)) + + self._try_soft_update_target() + + def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: + return transition_batch + + def get_local_ops(self) -> SoftActorCriticOps: + return SoftActorCriticOps( + name=self._policy_name, + policy_creator=self._policy_creator, + parallelism=self._params.data_parallelism, + **self._params.extract_ops_params(), + ) + + def _get_batch(self, batch_size: int = None) -> TransitionBatch: + return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size) + + def _try_soft_update_target(self) -> None: + """Soft update the target policy and target critic. + """ + self._qnet_version += 1 + if self._qnet_version - self._target_qnet_version == self._params.update_target_every: + self._ops.soft_update_target() + self._target_qnet_version = self._qnet_version diff --git a/maro/rl/training/learner.py b/maro/rl/training/learner.py deleted file mode 100644 index 58422718d..000000000 --- a/maro/rl/training/learner.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from abc import ABC, abstractmethod -from os import getcwd -from typing import Union - -from numpy import asarray - -from maro.rl.agent import AbsAgent, MultiAgentWrapper -from maro.rl.scheduling import Scheduler -from maro.rl.storage import SimpleStore -from maro.rl.utils import ExperienceCollectionUtils -from maro.utils import Logger - -from .actor import Actor -from .actor_proxy import ActorProxy - - -class AbsLearner(ABC): - """Learner class. - - Args: - actor (Union[Actor, ActorProxy]): ``Actor`` or ``ActorProxy`` instance responsible for collecting roll-out - data for learning purposes. If it is an ``Actor``, it will perform roll-outs locally. If it is an - ``ActorProxy``, it will coordinate a set of remote actors to perform roll-outs in parallel. - agent (Union[AbsAgent, MultiAgentWrapper]): Learning agents. If None, the actor must be an ``Actor`` that - contains actual agents, rather than an ``ActorProxy``. Defaults to None. 
- """ - def __init__( - self, - actor: Union[Actor, ActorProxy], - agent: Union[AbsAgent, MultiAgentWrapper] = None, - log_dir: str = getcwd() - ): - super().__init__() - if isinstance(actor, ActorProxy): - assert agent, "agent cannot be None when the actor is a proxy." - self.agent = agent - else: - # The agent passed to __init__ is ignored in this case - self.agent = actor.agent - self.actor = actor - self.logger = Logger("LEARNER", dump_folder=log_dir) - - @abstractmethod - def run(self): - """Main learning loop is implemented here.""" - return NotImplementedError - - -class OnPolicyLearner(AbsLearner): - def __init__( - self, - actor: Union[Actor, ActorProxy], - max_episode: int, - agent: Union[AbsAgent, MultiAgentWrapper] = None, - log_dir: str = getcwd() - ): - super().__init__(actor, agent=agent, log_dir=log_dir) - self.max_episode = max_episode - - def run(self): - for ep in range(self.max_episode): - env_metrics, exp = self.actor.roll_out( - ep, model_by_agent=self.agent.dump_model() if isinstance(self.actor, ActorProxy) else None - ) - self.logger.info(f"ep-{ep}: {env_metrics}") - exp = ExperienceCollectionUtils.stack( - exp, - is_single_source=isinstance(self.actor, Actor), - is_single_agent=isinstance(self.agent, AbsAgent) - ) - if isinstance(self.agent, AbsAgent): - for e in exp: - self.agent.learn(*e["args"], **e.get("kwargs", {})) - else: - for agent_id, ex in exp.items(): - for e in ex: - self.agent[agent_id].learn(*e["args"], **e.get("kwargs", {})) - - self.logger.info("Agent learning finished") - - # Signal remote actors to quit - if isinstance(self.actor, ActorProxy): - self.actor.terminate() - - -MAX_LOSS = 1e8 - - -class OffPolicyLearner(AbsLearner): - def __init__( - self, - actor: Union[Actor, ActorProxy], - scheduler: Scheduler, - agent: Union[AbsAgent, MultiAgentWrapper] = None, - train_iter: int = 1, - min_experiences_to_train: int = 0, - batch_size: int = 128, - prioritized_sampling_by_loss: bool = False, - log_dir: str = getcwd() - ): - super().__init__(actor, agent=agent, log_dir=log_dir) - self.scheduler = scheduler - if isinstance(self.agent, AbsAgent): - self.experience_pool = SimpleStore(["S", "A", "R", "S_", "loss"]) - else: - self.experience_pool = { - agent: SimpleStore(["S", "A", "R", "S_", "loss"]) for agent in self.agent.agent_dict - } - self.train_iter = train_iter - self.min_experiences_to_train = min_experiences_to_train - self.batch_size = batch_size - self.prioritized_sampling_by_loss = prioritized_sampling_by_loss - - def run(self): - for exploration_params in self.scheduler: - rollout_index = self.scheduler.iter - env_metrics, exp = self.actor.roll_out( - rollout_index, - model_by_agent=self.agent.dump_model() if isinstance(self.actor, ActorProxy) else None, - exploration_params=exploration_params - ) - self.logger.info(f"ep-{rollout_index}: {env_metrics} ({exploration_params})") - - # store experiences in the experience pool. 
- exp = ExperienceCollectionUtils.concat( - exp, - is_single_source=isinstance(self.actor, Actor), - is_single_agent=isinstance(self.agent, AbsAgent) - ) - if isinstance(self.agent, AbsAgent): - exp.update({"loss": [MAX_LOSS] * len(list(exp.values())[0])}) - self.experience_pool.put(exp) - for i in range(self.train_iter): - batch, idx = self.get_batch() - loss = self.agent.learn(*batch) - self.experience_pool.update(idx, {"loss": list(loss)}) - else: - for agent_id, ex in exp.items(): - # ensure new experiences are sampled with the highest priority - ex.update({"loss": [MAX_LOSS] * len(list(ex.values())[0])}) - self.experience_pool[agent_id].put(ex) - - for i in range(self.train_iter): - batch_by_agent, idx_by_agent = self.get_batch() - loss_by_agent = { - agent_id: self.agent[agent_id].learn(*batch) for agent_id, batch in batch_by_agent.items() - } - for agent_id, loss in loss_by_agent.items(): - self.experience_pool[agent_id].update(idx_by_agent[agent_id], {"loss": list(loss)}) - - self.logger.info("Agent learning finished") - - # Signal remote actors to quit - if isinstance(self.actor, ActorProxy): - self.actor.terminate() - - def get_batch(self): - if isinstance(self.agent, AbsAgent): - if len(self.experience_pool) < self.min_experiences_to_train: - return None, None - if self.prioritized_sampling_by_loss: - indexes, sample = self.experience_pool.sample_by_key("loss", self.batch_size) - else: - indexes, sample = self.experience_pool.sample(self.batch_size) - batch = asarray(sample["S"]), asarray(sample["A"]), asarray(sample["R"]), asarray(sample["S_"]) - return batch, indexes - else: - idx, batch = {}, {} - for agent_id, pool in self.experience_pool.items(): - if len(pool) < self.min_experiences_to_train: - continue - if self.prioritized_sampling_by_loss: - indexes, sample = self.experience_pool[agent_id].sample_by_key("loss", self.batch_size) - else: - indexes, sample = self.experience_pool[agent_id].sample(self.batch_size) - batch[agent_id] = ( - asarray(sample["S"]), asarray(sample["A"]), asarray(sample["R"]), asarray(sample["S_"]) - ) - idx[agent_id] = indexes - - return batch, idx diff --git a/maro/rl/training/message_enums.py b/maro/rl/training/message_enums.py deleted file mode 100644 index 44302d47e..000000000 --- a/maro/rl/training/message_enums.py +++ /dev/null @@ -1,24 +0,0 @@ -from enum import Enum - - -class MessageTag(Enum): - ROLLOUT = "rollout" - CHOOSE_ACTION = "choose_action" - ACTION = "action" - ABORT_ROLLOUT = "abort_rollout" - TRAIN = "train" - FINISHED = "finished" - EXIT = "exit" - - -class PayloadKey(Enum): - ACTION = "action" - AGENT_ID = "agent_id" - ROLLOUT_INDEX = "rollout_index" - TIME_STEP = "time_step" - METRICS = "metrics" - DETAILS = "details" - STATE = "state" - TRAINING = "training" - MODEL = "model" - EXPLORATION_PARAMS = "exploration_params" diff --git a/maro/rl/training/proxy.py b/maro/rl/training/proxy.py new file mode 100644 index 000000000..29eaaed7a --- /dev/null +++ b/maro/rl/training/proxy.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from collections import defaultdict, deque + +from maro.rl.distributed import AbsProxy +from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes +from maro.rl.utils.torch_utils import average_grads +from maro.utils import LoggerV2 + + +class TrainingProxy(AbsProxy): + """Intermediary between trainers and workers. 
+ + The proxy receives compute tasks from multiple ``AbsTrainOps`` instances, forwards them to a set of back-end + ``TrainOpsWorker``s to be processed and returns the results to the clients. + + Args: + frontend_port (int, default=10000): Network port for communicating with clients (task producers). + backend_port (int, default=10001): Network port for communicating with back-end workers (task consumers). + """ + + def __init__(self, frontend_port: int = 10000, backend_port: int = 10001) -> None: + super(TrainingProxy, self).__init__(frontend_port=frontend_port, backend_port=backend_port) + self._available_workers = deque() + self._worker_ready = False + self._connected_ops = set() + self._result_cache = defaultdict(list) + self._expected_num_results = {} + self._logger = LoggerV2("TRAIN-PROXY") + + def _route_request_to_compute_node(self, msg: list) -> None: + """ + Here we use a least-recently-used (LRU) routing strategy to select workers for a task while making the best + effort to satisfy the task's desired parallelism. For example, consider a task that specifies a desired + parallelism K (for gradient computation). If there are more than K workers in the ``_available_workers`` queue, + the first, i.e., the least recently used, K of them will be selected to process the task. If there are fewer + than K workers in the queue, all workers will be popped from the queue to process the task. In this case, the + desired parallelism cannot be satisfied, but waiting is avoided. + """ + if msg[-1] == b"EXIT": + self._connected_ops.remove(msg[0]) + # if all clients (ops) have signaled exit, tell the workers to terminate + if not self._connected_ops: + for worker_id in self._available_workers: + self._dispatch_endpoint.send_multipart([worker_id, b"EXIT"]) + return + + self._connected_ops.add(msg[0]) + req = bytes_to_pyobj(msg[-1]) + desired_parallelism = req["desired_parallelism"] + req["args"] = list(req["args"]) + batch = req["args"][0] + workers = [] + while len(workers) < desired_parallelism and self._available_workers: + workers.append(self._available_workers.popleft()) + + self._expected_num_results[msg[0]] = len(workers) + for worker_id, sub_batch in zip(workers, batch.split(len(workers))): + req["args"][0] = sub_batch + self._dispatch_endpoint.send_multipart([worker_id, msg[0], pyobj_to_bytes(req)]) + + if not self._available_workers: + # stop receiving compute requests until at least one worker becomes available + self._workers_ready = False + self._req_endpoint.stop_on_recv() + + def _send_result_to_requester(self, msg: list) -> None: + if msg[1] == b"EXIT_ACK": + self._logger.info("Exiting event loop...") + self.stop() + return + + if msg[1] != b"READY": + ops_name = msg[1] + self._result_cache[ops_name].append(bytes_to_pyobj(msg[-1])) + if len(self._result_cache[ops_name]) == self._expected_num_results[ops_name]: + aggregated_result = average_grads(self._result_cache[ops_name]) + self._logger.info(f"Aggregated {len(self._result_cache[ops_name])} results for {ops_name}") + self._result_cache[ops_name].clear() + self._req_endpoint.send_multipart([ops_name, pyobj_to_bytes(aggregated_result)]) + + self._available_workers.append(msg[0]) + self._worker_ready = True + self._req_endpoint.on_recv(self._route_request_to_compute_node) diff --git a/maro/rl/training/replay_memory.py b/maro/rl/training/replay_memory.py new file mode 100644 index 000000000..0fb6da63d --- /dev/null +++ b/maro/rl/training/replay_memory.py @@ -0,0 +1,504 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT license. + +from abc import ABCMeta, abstractmethod +from typing import List + +import numpy as np + +from maro.rl.utils import match_shape, MultiTransitionBatch, SHAPE_CHECK_FLAG, TransitionBatch + + +class AbsIndexScheduler(object, metaclass=ABCMeta): + """Scheduling indexes for read and write requests. This is used as an inner module of the replay memory. + + Args: + capacity (int): Maximum capacity of the replay memory. + """ + + def __init__(self, capacity: int) -> None: + super(AbsIndexScheduler, self).__init__() + self._capacity = capacity + + @abstractmethod + def get_put_indexes(self, batch_size: int) -> np.ndarray: + """Generate a list of indexes to the replay memory for writing. In other words, when the replay memory + need to write a batch, the scheduler should provide a set of proper indexes for the replay memory to + write. + + Args: + batch_size (int): The required batch size. + + Returns: + indexes (np.ndarray): The list of indexes. + """ + raise NotImplementedError + + @abstractmethod + def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray: + """Generate a list of indexes that can be used to retrieve items from the replay memory. + + Args: + batch_size (int, default=None): The required batch size. If it is None, all indexes where an experience + item is present are returned. + forbid_last (bool, default=False): Whether the latest element is allowed to be sampled. + If this is true, the last index will always be excluded from the result. + + Returns: + indexes (np.ndarray): The list of indexes. + """ + raise NotImplementedError + + @abstractmethod + def get_last_index(self) -> int: + """Get the index of the latest element in the memory. + + Returns: + index (int): The index of the latest element in the memory. + """ + raise NotImplementedError + + +class RandomIndexScheduler(AbsIndexScheduler): + """Index scheduler that returns random indexes when sampling. + + Args: + capacity (int): Maximum capacity of the replay memory. + random_overwrite (bool): Flag that controls the overwriting behavior when the replay memory reaches capacity. + If this is true, newly added items will randomly overwrite existing ones. Otherwise, the overwrite occurs + in a cyclic manner. + """ + + def __init__(self, capacity: int, random_overwrite: bool) -> None: + super(RandomIndexScheduler, self).__init__(capacity) + self._random_overwrite = random_overwrite + self._ptr = self._size = 0 + + def get_put_indexes(self, batch_size: int) -> np.ndarray: + if self._ptr + batch_size <= self._capacity: + indexes = np.arange(self._ptr, self._ptr + batch_size) + self._ptr += batch_size + else: + overwrites = self._ptr + batch_size - self._capacity + indexes = np.concatenate([ + np.arange(self._ptr, self._capacity), + np.random.choice(self._ptr, size=overwrites, replace=False) if self._random_overwrite + else np.arange(overwrites) + ]) + self._ptr = self._capacity if self._random_overwrite else overwrites + + self._size = min(self._size + batch_size, self._capacity) + return indexes + + def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray: + assert batch_size is not None and batch_size > 0, f"Invalid batch size: {batch_size}" + assert self._size > 0, "Cannot sample from an empty memory." 
+ return np.random.choice(self._size, size=batch_size, replace=True) + + def get_last_index(self) -> int: + raise NotImplementedError + + +class FIFOIndexScheduler(AbsIndexScheduler): + """First-in-first-out index scheduler. + + Args: + capacity (int): Maximum capacity of the replay memory. + """ + + def __init__(self, capacity: int) -> None: + super(FIFOIndexScheduler, self).__init__(capacity) + self._head = self._tail = 0 + + @property + def size(self) -> int: + return (self._tail - self._head) % self._capacity + + def get_put_indexes(self, batch_size: int) -> np.ndarray: + if self.size + batch_size <= self._capacity: + if self._tail + batch_size <= self._capacity: + indexes = np.arange(self._tail, self._tail + batch_size) + else: + indexes = np.concatenate([ + np.arange(self._tail, self._capacity), + np.arange(self._tail + batch_size - self._capacity) + ]) + self._tail = (self._tail + batch_size) % self._capacity + return indexes + else: + overwrite = self.size + batch_size - self._capacity + self._head = (self._head + overwrite) % self._capacity + return self.get_put_indexes(batch_size) + + def get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray: + tmp = self._tail if not forbid_last else (self._tail - 1) % self._capacity + indexes = np.arange(self._head, tmp) if tmp > self._head \ + else np.concatenate([np.arange(self._head, self._capacity), np.arange(tmp)]) + self._head = tmp + return indexes + + def get_last_index(self) -> int: + return (self._tail - 1) % self._capacity + + +class AbsReplayMemory(object, metaclass=ABCMeta): + """Abstract replay memory class with basic interfaces. + + Args: + capacity (int): Maximum capacity of the replay memory. + state_dim (int): Dimension of states. + idx_scheduler (AbsIndexScheduler): The index scheduler. + """ + + def __init__(self, capacity: int, state_dim: int, idx_scheduler: AbsIndexScheduler) -> None: + super(AbsReplayMemory, self).__init__() + self._capacity = capacity + self._state_dim = state_dim + self._idx_scheduler = idx_scheduler + + @property + def capacity(self) -> int: + return self._capacity + + @property + def state_dim(self) -> int: + return self._state_dim + + def _get_put_indexes(self, batch_size: int) -> np.ndarray: + """Please refer to the doc string in AbsIndexScheduler. + """ + return self._idx_scheduler.get_put_indexes(batch_size) + + def _get_sample_indexes(self, batch_size: int = None, forbid_last: bool = False) -> np.ndarray: + """Please refer to the doc string in AbsIndexScheduler. + """ + return self._idx_scheduler.get_sample_indexes(batch_size, forbid_last) + + +class ReplayMemory(AbsReplayMemory, metaclass=ABCMeta): + """In-memory experience storage facility for a single trainer. + + Args: + capacity (int): Maximum capacity of the replay memory. + state_dim (int): Dimension of states. + action_dim (int): Dimension of actions. + idx_scheduler (AbsIndexScheduler): The index scheduler. 
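+
+    Example:
+        An illustrative sketch only; the dimensions are arbitrary and, in practice, a concrete subclass such as
+        ``RandomReplayMemory`` or ``FIFOReplayMemory`` is created inside a trainer's ``build()``::
+
+            memory = RandomReplayMemory(capacity=1000, state_dim=4, action_dim=2)
+            memory.put(batch)            # ``batch`` is a TransitionBatch whose arrays match the dimensions above
+            sampled = memory.sample(32)  # a TransitionBatch of 32 transitions sampled uniformly (with replacement)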
+ """ + + def __init__( + self, + capacity: int, + state_dim: int, + action_dim: int, + idx_scheduler: AbsIndexScheduler, + ) -> None: + super(ReplayMemory, self).__init__(capacity, state_dim, idx_scheduler) + self._action_dim = action_dim + + self._states = np.zeros((self._capacity, self._state_dim), dtype=np.float32) + self._actions = np.zeros((self._capacity, self._action_dim), dtype=np.float32) + self._rewards = np.zeros(self._capacity, dtype=np.float32) + self._terminals = np.zeros(self._capacity, dtype=np.bool) + self._next_states = np.zeros((self._capacity, self._state_dim), dtype=np.float32) + self._returns = np.zeros(self._capacity, dtype=np.float32) + self._advantages = np.zeros(self._capacity, dtype=np.float32) + self._old_logps = np.zeros(self._capacity, dtype=np.float32) + + self._n_sample = 0 + + @property + def action_dim(self) -> int: + return self._action_dim + + @property + def n_sample(self) -> int: + return self._n_sample + + def put(self, transition_batch: TransitionBatch) -> None: + """Store a transition batch in the memory. + + Args: + transition_batch (TransitionBatch): The transition batch. + """ + batch_size = len(transition_batch.states) + if SHAPE_CHECK_FLAG: + assert 0 < batch_size <= self._capacity + assert match_shape(transition_batch.states, (batch_size, self._state_dim)) + assert match_shape(transition_batch.actions, (batch_size, self._action_dim)) + assert match_shape(transition_batch.rewards, (batch_size,)) + assert match_shape(transition_batch.terminals, (batch_size,)) + assert match_shape(transition_batch.next_states, (batch_size, self._state_dim)) + if transition_batch.returns is not None: + match_shape(transition_batch.returns, (batch_size,)) + if transition_batch.advantages is not None: + match_shape(transition_batch.advantages, (batch_size,)) + if transition_batch.old_logps is not None: + match_shape(transition_batch.old_logps, (batch_size,)) + + self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch) + self._n_sample = min(self._n_sample + transition_batch.size, self._capacity) + + def _put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None: + """Store a transition batch into the memory at the give indexes. + + Args: + indexes (np.ndarray): Positions in the replay memory to store at. + transition_batch (TransitionBatch): The transition batch. + """ + self._states[indexes] = transition_batch.states + self._actions[indexes] = transition_batch.actions + self._rewards[indexes] = transition_batch.rewards + self._terminals[indexes] = transition_batch.terminals + self._next_states[indexes] = transition_batch.next_states + if transition_batch.returns is not None: + self._returns[indexes] = transition_batch.returns + if transition_batch.advantages is not None: + self._advantages[indexes] = transition_batch.advantages + if transition_batch.old_logps is not None: + self._old_logps[indexes] = transition_batch.old_logps + + def sample(self, batch_size: int = None) -> TransitionBatch: + """Generate a sample batch from the replay memory. + + Args: + batch_size (int, default=None): The required batch size. If it is None, all indexes where an experience + item is present are returned. + + Returns: + batch (TransitionBatch): The sampled batch. + """ + indexes = self._get_sample_indexes(batch_size, self._get_forbid_last()) + return self.sample_by_indexes(indexes) + + def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch: + """Retrieve items at given indexes from the replay memory. 
+ + Args: + indexes (np.ndarray): Positions in the replay memory to retrieve at. + + Returns: + batch (TransitionBatch): The sampled batch. + """ + assert all([0 <= idx < self._capacity for idx in indexes]) + + return TransitionBatch( + states=self._states[indexes], + actions=self._actions[indexes], + rewards=self._rewards[indexes], + terminals=self._terminals[indexes], + next_states=self._next_states[indexes], + returns=self._returns[indexes], + advantages=self._advantages[indexes], + old_logps=self._old_logps[indexes], + ) + + @abstractmethod + def _get_forbid_last(self) -> bool: + raise NotImplementedError + + +class RandomReplayMemory(ReplayMemory): + def __init__( + self, + capacity: int, + state_dim: int, + action_dim: int, + random_overwrite: bool = False, + ) -> None: + super(RandomReplayMemory, self).__init__( + capacity, state_dim, action_dim, RandomIndexScheduler(capacity, random_overwrite) + ) + self._random_overwrite = random_overwrite + self._scheduler = RandomIndexScheduler(capacity, random_overwrite) + + @property + def random_overwrite(self) -> bool: + return self._random_overwrite + + def _get_forbid_last(self) -> bool: + return False + + +class FIFOReplayMemory(ReplayMemory): + def __init__( + self, + capacity: int, + state_dim: int, + action_dim: int, + ) -> None: + super(FIFOReplayMemory, self).__init__( + capacity, state_dim, action_dim, FIFOIndexScheduler(capacity) + ) + + def _get_forbid_last(self) -> bool: + return not self._terminals[self._idx_scheduler.get_last_index()] + + +class MultiReplayMemory(AbsReplayMemory, metaclass=ABCMeta): + """In-memory experience storage facility for a multi trainer. + + Args: + capacity (int): Maximum capacity of the replay memory. + state_dim (int): Dimension of states. + action_dims (List[int]): Dimensions of actions. + idx_scheduler (AbsIndexScheduler): The index scheduler. + agent_states_dims (List[int]): Dimensions of agent states. + """ + + def __init__( + self, + capacity: int, + state_dim: int, + action_dims: List[int], + idx_scheduler: AbsIndexScheduler, + agent_states_dims: List[int], + ) -> None: + super(MultiReplayMemory, self).__init__(capacity, state_dim, idx_scheduler) + self._agent_num = len(action_dims) + self._action_dims = action_dims + + self._states = np.zeros((self._capacity, self._state_dim), dtype=np.float32) + self._actions = [np.zeros((self._capacity, action_dim), dtype=np.float32) for action_dim in self._action_dims] + self._rewards = [np.zeros(self._capacity, dtype=np.float32) for _ in range(self.agent_num)] + self._next_states = np.zeros((self._capacity, self._state_dim), dtype=np.float32) + self._terminals = np.zeros(self._capacity, dtype=np.bool) + + assert len(agent_states_dims) == self.agent_num + self._agent_states_dims = agent_states_dims + self._agent_states = [ + np.zeros((self._capacity, state_dim), dtype=np.float32) for state_dim in self._agent_states_dims + ] + self._next_agent_states = [ + np.zeros((self._capacity, state_dim), dtype=np.float32) for state_dim in self._agent_states_dims + ] + + @property + def action_dims(self) -> List[int]: + return self._action_dims + + @property + def agent_num(self) -> int: + return self._agent_num + + def put(self, transition_batch: MultiTransitionBatch) -> None: + """Store a transition batch into the memory. + + Args: + transition_batch (MultiTransitionBatch): The transition batch. 
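+                The ``actions``, ``rewards``, ``agent_states`` and ``next_agent_states`` fields are per-agent lists;
+                each entry must match this memory's ``action_dims`` / ``agent_states_dims``.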
+ """ + batch_size = len(transition_batch.states) + if SHAPE_CHECK_FLAG: + assert 0 < batch_size <= self._capacity + assert match_shape(transition_batch.states, (batch_size, self._state_dim)) + assert len(transition_batch.actions) == len(transition_batch.rewards) == self.agent_num + for i in range(self.agent_num): + assert match_shape(transition_batch.actions[i], (batch_size, self.action_dims[i])) + assert match_shape(transition_batch.rewards[i], (batch_size,)) + + assert match_shape(transition_batch.terminals, (batch_size,)) + assert match_shape(transition_batch.next_states, (batch_size, self._state_dim)) + + assert len(transition_batch.agent_states) == self.agent_num + assert len(transition_batch.next_agent_states) == self.agent_num + for i in range(self.agent_num): + assert match_shape(transition_batch.agent_states[i], (batch_size, self._agent_states_dims[i])) + assert match_shape(transition_batch.next_agent_states[i], (batch_size, self._agent_states_dims[i])) + + self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch=transition_batch) + + def _put_by_indexes(self, indexes: np.ndarray, transition_batch: MultiTransitionBatch) -> None: + """Store a transition batch into the memory at the give indexes. + + Args: + indexes (np.ndarray): Positions in the replay memory to store at. + transition_batch (MultiTransitionBatch): The transition batch. + """ + self._states[indexes] = transition_batch.states + for i in range(self.agent_num): + self._actions[i][indexes] = transition_batch.actions[i] + self._rewards[i][indexes] = transition_batch.rewards[i] + self._terminals[indexes] = transition_batch.terminals + + self._next_states[indexes] = transition_batch.next_states + for i in range(self.agent_num): + self._agent_states[i][indexes] = transition_batch.agent_states[i] + self._next_agent_states[i][indexes] = transition_batch.next_agent_states[i] + + def sample(self, batch_size: int = None) -> MultiTransitionBatch: + """Generate a sample batch from the replay memory. + + Args: + batch_size (int, default=None): The required batch size. If it is None, all indexes where an experience + item is present are returned. + + Returns: + batch (MultiTransitionBatch): The sampled batch. + """ + indexes = self._get_sample_indexes(batch_size, self._get_forbid_last()) + return self.sample_by_indexes(indexes) + + def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch: + """Retrieve items at given indexes from the replay memory. + + Args: + indexes (np.ndarray): Positions in the replay memory to retrieve at. + + Returns: + batch (MultiTransitionBatch): The sampled batch. 
+ """ + assert all([0 <= idx < self._capacity for idx in indexes]) + + return MultiTransitionBatch( + states=self._states[indexes], + actions=[action[indexes] for action in self._actions], + rewards=[reward[indexes] for reward in self._rewards], + terminals=self._terminals[indexes], + next_states=self._next_states[indexes], + agent_states=[state[indexes] for state in self._agent_states], + next_agent_states=[state[indexes] for state in self._next_agent_states], + ) + + @abstractmethod + def _get_forbid_last(self) -> bool: + raise NotImplementedError + + +class RandomMultiReplayMemory(MultiReplayMemory): + def __init__( + self, + capacity: int, + state_dim: int, + action_dims: List[int], + agent_states_dims: List[int], + random_overwrite: bool = False, + ) -> None: + super(RandomMultiReplayMemory, self).__init__( + capacity, state_dim, action_dims, RandomIndexScheduler(capacity, random_overwrite), + agent_states_dims + ) + self._random_overwrite = random_overwrite + self._scheduler = RandomIndexScheduler(capacity, random_overwrite) + + @property + def random_overwrite(self) -> bool: + return self._random_overwrite + + def _get_forbid_last(self) -> bool: + return False + + +class FIFOMultiReplayMemory(MultiReplayMemory): + def __init__( + self, + capacity: int, + state_dim: int, + action_dims: List[int], + agent_states_dims: List[int], + ) -> None: + super(FIFOMultiReplayMemory, self).__init__( + capacity, state_dim, action_dims, FIFOIndexScheduler(capacity), + agent_states_dims, + ) + + def _get_forbid_last(self) -> bool: + return not self._terminals[self._idx_scheduler.get_last_index()] diff --git a/maro/rl/training/train_ops.py b/maro/rl/training/train_ops.py new file mode 100644 index 000000000..ddc306edd --- /dev/null +++ b/maro/rl/training/train_ops.py @@ -0,0 +1,246 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import inspect +from abc import ABCMeta, abstractmethod +from typing import Callable, Tuple + +import zmq +from zmq.asyncio import Context, Poller + +from maro.rl.policy import RLPolicy +from maro.rl.utils.common import bytes_to_pyobj, get_ip_address_by_hostname, pyobj_to_bytes +from maro.utils import DummyLogger, LoggerV2 + + +class AbsTrainOps(object, metaclass=ABCMeta): + """The basic component for training a policy, which takes charge of loss / gradient computation and policy update. + Each ops is used for training a single policy. An ops is an atomic unit in the distributed mode. + + Args: + name (str): Name of the ops. This is usually a policy name. + policy_creator (Callable[[], RLPolicy]): Function to create a policy instance. + parallelism (int, default=1): Desired degree of data parallelism. + """ + + def __init__( + self, + name: str, + policy_creator: Callable[[], RLPolicy], + parallelism: int = 1, + ) -> None: + super(AbsTrainOps, self).__init__() + self._name = name + self._policy_creator = policy_creator + # Create the policy. + if self._policy_creator: + self._policy = self._policy_creator() + + self._parallelism = parallelism + + @property + def name(self) -> str: + return self._name + + @property + def policy_state_dim(self) -> int: + return self._policy.state_dim if self._policy_creator else None + + @property + def policy_action_dim(self) -> int: + return self._policy.action_dim if self._policy_creator else None + + @property + def parallelism(self) -> int: + return self._parallelism + + def get_state(self) -> dict: + """Get the train ops's state. + + Returns: + A dict that contains ops's state. 
+ """ + return { + "policy": self.get_policy_state(), + "non_policy": self.get_non_policy_state(), + } + + def set_state(self, ops_state_dict: dict) -> None: + """Set ops's state. + + Args: + ops_state_dict (dict): New ops state. + """ + assert ops_state_dict["policy"][0] == self._policy.name + self.set_policy_state(ops_state_dict["policy"][1]) + self.set_non_policy_state(ops_state_dict["non_policy"]) + + def get_policy_state(self) -> Tuple[str, object]: + """Get the policy's state. + + Returns: + policy_name (str) + policy_state (object) + """ + return self._policy.name, self._policy.get_state() + + def set_policy_state(self, policy_state: object) -> None: + """Update the policy's state. + + Args: + policy_state (object): The policy state. + """ + self._policy.set_state(policy_state) + + @abstractmethod + def get_non_policy_state(self) -> dict: + """Get states other than policy. + + Returns: + A dict that contains non-policy state. + """ + raise NotImplementedError + + @abstractmethod + def set_non_policy_state(self, state: dict) -> None: + """Set states other than policy. + + Args: + state (dict): Non-policy state. + """ + raise NotImplementedError + + @abstractmethod + def to_device(self, device: str): + raise NotImplementedError + + +def remote(func) -> Callable: + """Annotation to indicate that a function / method can be called remotely. + + This annotation takes effect only when an ``AbsTrainOps`` object is wrapped by a ``RemoteOps``. + """ + + def remote_annotate(*args, **kwargs) -> object: + return func(*args, **kwargs) + + return remote_annotate + + +class AsyncClient(object): + """Facility used by a ``RemoteOps`` instance to communicate asynchronously with ``TrainingProxy``. + + Args: + name (str): Name of the client. + address (Tuple[str, int]): Address (host and port) of the training proxy. + logger (LoggerV2, default=None): logger. + """ + + def __init__(self, name: str, address: Tuple[str, int], logger: LoggerV2 = None) -> None: + self._logger = DummyLogger() if logger is None else logger + self._name = name + host, port = address + self._proxy_ip = get_ip_address_by_hostname(host) + self._address = f"tcp://{self._proxy_ip}:{port}" + self._logger.info(f"Proxy address: {self._address}") + + async def send_request(self, req: dict) -> None: + """Send a request to the proxy in asynchronous fashion. + + This is a coroutine and is executed asynchronously with calls to other AsyncClients' ``send_request`` calls. + + Args: + req (dict): Request that contains task specifications and parameters. + """ + await self._socket.send(pyobj_to_bytes(req)) + self._logger.debug(f"{self._name} sent request {req['func']}") + + async def get_response(self) -> object: + """Waits for a result in asynchronous fashion. + + This is a coroutine and is executed asynchronously with calls to other AsyncClients' ``get_response`` calls. + This ensures that all clients' tasks are sent out as soon as possible before the waiting for results starts. + """ + while True: + events = await self._poller.poll(timeout=100) + if self._socket in dict(events): + result = await self._socket.recv_multipart() + self._logger.debug(f"{self._name} received result") + return bytes_to_pyobj(result[0]) + + def close(self) -> None: + """Close the connection to the proxy. + """ + self._poller.unregister(self._socket) + self._socket.disconnect(self._address) + self._socket.close() + + def connect(self) -> None: + """Establish the connection to the proxy. 
+ """ + self._socket = Context.instance().socket(zmq.DEALER) + self._socket.setsockopt_string(zmq.IDENTITY, self._name) + self._socket.setsockopt(zmq.LINGER, 0) + self._socket.connect(self._address) + self._logger.debug(f"connected to {self._address}") + self._poller = Poller() + self._poller.register(self._socket, zmq.POLLIN) + + async def exit(self) -> None: + """Send EXIT signals to the proxy indicating no more tasks. + """ + await self._socket.send(b"EXIT") + + +class RemoteOps(object): + """Wrapper for ``AbsTrainOps``. + + RemoteOps provides similar interfaces to ``AbsTrainOps``. Any method annotated by the remote decorator in the + definition of the train ops is transformed to a remote method. Calling this method invokes using the internal + ``AsyncClient`` to send the required task parameters to a ``TrainingProxy`` that handles task dispatching and + result collection. Methods not annotated by the decorator are not affected. + + Args: + ops (AbsTrainOps): An ``AbsTrainOps`` instance to be wrapped. Any method annotated by the remote decorator in + its definition is transformed to a remote function call. + address (Tuple[str, int]): Address (host and port) of the training proxy. + logger (LoggerV2, default=None): logger. + """ + + def __init__(self, ops: AbsTrainOps, address: Tuple[str, int], logger: LoggerV2 = None) -> None: + self._ops = ops + self._client = AsyncClient(self._ops.name, address, logger=logger) + self._client.connect() + + def __getattribute__(self, attr_name: str) -> object: + # Ignore methods that belong to the parent class + try: + return super().__getattribute__(attr_name) + except AttributeError: + pass + + def remote_method(ops_state, func_name: str, desired_parallelism: int, client: AsyncClient) -> Callable: + async def remote_call(*args, **kwargs) -> object: + req = { + "state": ops_state, + "func": func_name, + "args": args, + "kwargs": kwargs, + "desired_parallelism": desired_parallelism, + } + await client.send_request(req) + response = await client.get_response() + return response + + return remote_call + + attr = getattr(self._ops, attr_name) + if inspect.ismethod(attr) and attr.__name__ == "remote_annotate": + return remote_method(self._ops.get_state(), attr_name, self._ops.parallelism, self._client) + + return attr + + async def exit(self) -> None: + """Close the internal task client. + """ + await self._client.exit() diff --git a/maro/rl/training/trainer.py b/maro/rl/training/trainer.py new file mode 100644 index 000000000..1dc0dac2c --- /dev/null +++ b/maro/rl/training/trainer.py @@ -0,0 +1,341 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import collections +import os +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from maro.rl.policy import AbsPolicy, RLPolicy +from maro.rl.rollout import ExpElement +from maro.rl.utils import TransitionBatch +from maro.rl.utils.objects import FILE_SUFFIX +from maro.utils import LoggerV2 + +from .replay_memory import ReplayMemory +from .train_ops import AbsTrainOps, RemoteOps + + +@dataclass +class TrainerParams: + """Common trainer parameters. + + replay_memory_capacity (int, default=100000): Maximum capacity of the replay memory. + batch_size (int, default=128): Training batch size. + data_parallelism (int, default=1): Degree of data parallelism. 
A value greater than 1 can be used when
+            a model is large and computing gradients with respect to a batch becomes expensive. In this case, the
+            batch may be split into multiple smaller batches whose gradients can be computed in parallel on a set
+            of remote nodes. For simplicity, only synchronous parallelism is supported, meaning that the model gets
+            updated only after collecting all the gradients from the remote nodes. Note that this value is the desired
+            parallelism and the actual parallelism in a distributed experiment may be smaller depending on the
+            availability of compute resources. For details on distributed deep learning and data parallelism, see
+            https://web.stanford.edu/~rezab/classes/cme323/S16/projects_reports/hedge_usmani.pdf, as well as an abundance
+            of resources available on the internet.
+        reward_discount (float, default=0.9): Reward decay as defined in standard RL terminology.
+
+    """
+    replay_memory_capacity: int = 10000
+    batch_size: int = 128
+    data_parallelism: int = 1
+    reward_discount: float = 0.9
+
+    @abstractmethod
+    def extract_ops_params(self) -> Dict[str, object]:
+        """Extract parameters that should be passed to the train ops.
+
+        Returns:
+            params (Dict[str, object]): Parameter dict.
+        """
+        raise NotImplementedError
+
+
+class AbsTrainer(object, metaclass=ABCMeta):
+    """Policy trainer used to train policies. A trainer maintains a group of train ops and controls their training
+    logic, while the train ops take charge of the actual policy updates.
+
+    A trainer holds one or more replay memories to store experiences, and it also maintains a copy of every policy
+    it trains. However, the trainer itself does not do any actual computation; all computation is done in the
+    train ops.
+
+    Args:
+        name (str): Name of the trainer.
+        params (TrainerParams): Trainer's parameters.
+    """
+
+    def __init__(self, name: str, params: TrainerParams) -> None:
+        self._name = name
+        self._params = params
+        self._batch_size = self._params.batch_size
+        self._agent2policy: Dict[Any, str] = {}
+        self._proxy_address: Optional[Tuple[str, int]] = None
+        self._logger = None
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def agent_num(self) -> int:
+        return len(self._agent2policy)
+
+    def register_logger(self, logger: LoggerV2) -> None:
+        self._logger = logger
+
+    def register_agent2policy(self, agent2policy: Dict[Any, str], policy_trainer_mapping: Dict[str, str]) -> None:
+        """Register the agent-to-policy mappings that correspond to the current trainer. A valid policy name should
+        start with the name of its trainer, for example, "DQN.POLICY_NAME". Therefore, we can identify which policies
+        should be registered to the current trainer according to the policy's name.
+
+        Args:
+            agent2policy (Dict[Any, str]): Agent name to policy name mapping.
+            policy_trainer_mapping (Dict[str, str]): Policy name to trainer name mapping.
+        """
+        self._agent2policy = {
+            agent_name: policy_name for agent_name, policy_name in agent2policy.items()
+            if policy_trainer_mapping[policy_name] == self.name
+        }
+
+    @abstractmethod
+    def register_policy_creator(
+        self,
+        global_policy_creator: Dict[str, Callable[[], AbsPolicy]],
+        policy_trainer_mapping: Dict[str, str],
+    ) -> None:
+        """Register the policy creator. Only keep the creators of the policies that the current trainer needs to
+        train.
+
+        Args:
+            global_policy_creator (Dict[str, Callable[[], AbsPolicy]]): Dict that contains the creators for all
+                policies.
+            policy_trainer_mapping (Dict[str, str]): Policy name to trainer name mapping.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def build(self) -> None:
+        """Create the required train ops and replay memory. This should be called before invoking `train_step` or
+        `train_step_as_task`.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def train_step(self) -> None:
+        """Run a training step to update all the policies that this trainer is responsible for.
+        """
+        raise NotImplementedError
+
+    async def train_step_as_task(self) -> None:
+        """Update all policies managed by the trainer as an asynchronous task.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None:
+        """Record all rollout experiences from an environment in the replay memory.
+
+        Args:
+            env_idx (int): The index of the environment that generates this batch of experiences. This is used
+                when more than one environment is collecting experiences in parallel.
+            exp_elements (List[ExpElement]): Experiences.
+        """
+        raise NotImplementedError
+
+    def set_proxy_address(self, proxy_address: Tuple[str, int]) -> None:
+        self._proxy_address = proxy_address
+
+    @abstractmethod
+    def get_policy_state(self) -> Dict[str, object]:
+        """Get policies' states.
+
+        Returns:
+            A dict with format: {policy_name: policy_state}.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def load(self, path: str) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def save(self, path: str) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def exit(self) -> None:
+        raise NotImplementedError
+
+
+class SingleAgentTrainer(AbsTrainer, metaclass=ABCMeta):
+    """Policy trainer that trains only one policy.
+    """
+
+    def __init__(self, name: str, params: TrainerParams) -> None:
+        super(SingleAgentTrainer, self).__init__(name, params)
+        self._policy_name: Optional[str] = None
+        self._policy_creator: Optional[Callable[[], RLPolicy]] = None
+        self._ops: Optional[AbsTrainOps] = None
+        self._replay_memory: Optional[ReplayMemory] = None
+
+    @property
+    def ops(self):
+        return self._ops
+
+    def register_policy_creator(
+        self,
+        global_policy_creator: Dict[str, Callable[[], AbsPolicy]],
+        policy_trainer_mapping: Dict[str, str],
+    ) -> None:
+        policy_names = [
+            policy_name
+            for policy_name in global_policy_creator
+            if policy_trainer_mapping[policy_name] == self.name
+        ]
+        if len(policy_names) != 1:
+            raise ValueError(f"Trainer {self._name} should have exactly one policy assigned to it")
+
+        self._policy_name = policy_names.pop()
+        self._policy_creator = global_policy_creator[self._policy_name]
+
+    @abstractmethod
+    def get_local_ops(self) -> AbsTrainOps:
+        """Create an `AbsTrainOps` instance associated with the policy.
+
+        Returns:
+            ops (AbsTrainOps): The local ops.
+        """
+        raise NotImplementedError
+
+    def get_ops(self) -> Union[RemoteOps, AbsTrainOps]:
+        """Create an `AbsTrainOps` instance associated with the policy. If a proxy address has been registered to the
+        trainer, this returns a `RemoteOps` instance in which all methods annotated as "remote" are turned into remote
+        method calls. Otherwise, a regular `AbsTrainOps` is returned.
+
+        Returns:
+            ops (Union[RemoteOps, AbsTrainOps]): The ops.
+ """ + ops = self.get_local_ops() + return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops + + def get_policy_state(self) -> Dict[str, object]: + self._assert_ops_exists() + policy_name, state = self._ops.get_policy_state() + return {policy_name: state} + + def load(self, path: str) -> None: + self._assert_ops_exists() + + policy_state = torch.load(os.path.join(path, f"{self.name}_policy.{FILE_SUFFIX}")) + non_policy_state = torch.load(os.path.join(path, f"{self.name}_non_policy.{FILE_SUFFIX}")) + + self._ops.set_state({ + "policy": policy_state, + "non_policy": non_policy_state, + }) + + def save(self, path: str) -> None: + self._assert_ops_exists() + + ops_state = self._ops.get_state() + policy_state = ops_state["policy"] + non_policy_state = ops_state["non_policy"] + + torch.save(policy_state, os.path.join(path, f"{self.name}_policy.{FILE_SUFFIX}")) + torch.save(non_policy_state, os.path.join(path, f"{self.name}_non_policy.{FILE_SUFFIX}")) + + def record_multiple(self, env_idx: int, exp_elements: List[ExpElement]) -> None: + agent_exp_pool = collections.defaultdict(list) + for exp_element in exp_elements: + for agent_name in exp_element.agent_names: + agent_exp_pool[agent_name].append(( + exp_element.agent_state_dict[agent_name], + exp_element.action_dict[agent_name], + exp_element.reward_dict[agent_name], + exp_element.terminal_dict[agent_name], + exp_element.next_agent_state_dict.get(agent_name, exp_element.agent_state_dict[agent_name]), + )) + + for agent_name, exps in agent_exp_pool.items(): + transition_batch = TransitionBatch( + states=np.vstack([exp[0] for exp in exps]), + actions=np.vstack([exp[1] for exp in exps]), + rewards=np.array([exp[2] for exp in exps]), + terminals=np.array([exp[3] for exp in exps]), + next_states=np.vstack([exp[4] for exp in exps]), + ) + transition_batch = self._preprocess_batch(transition_batch) + self._replay_memory.put(transition_batch) + + @abstractmethod + def _preprocess_batch(self, transition_batch: TransitionBatch) -> TransitionBatch: + raise NotImplementedError + + def _assert_ops_exists(self) -> None: + if not self._ops: + raise ValueError("'build' needs to be called to create an ops instance first.") + + async def exit(self) -> None: + self._assert_ops_exists() + if isinstance(self._ops, RemoteOps): + await self._ops.exit() + + +class MultiAgentTrainer(AbsTrainer, metaclass=ABCMeta): + """Policy trainer that trains multiple policies. + """ + + def __init__(self, name: str, params: TrainerParams) -> None: + super(MultiAgentTrainer, self).__init__(name, params) + self._policy_creator: Dict[str, Callable[[], RLPolicy]] = {} + self._policy_names: List[str] = [] + self._ops_dict: Dict[str, AbsTrainOps] = {} + + @property + def ops_dict(self): + return self._ops_dict + + def register_policy_creator( + self, + global_policy_creator: Dict[str, Callable[[], AbsPolicy]], + policy_trainer_mapping: Dict[str, str], + ) -> None: + self._policy_creator: Dict[str, Callable[[], RLPolicy]] = { + policy_name: func for policy_name, func in global_policy_creator.items() + if policy_trainer_mapping[policy_name] == self.name + } + self._policy_names = list(self._policy_creator.keys()) + + @abstractmethod + def get_local_ops(self, name: str) -> AbsTrainOps: + """Create an `AbsTrainOps` instance with a given name. + + Args: + name (str): Ops name. + + Returns: + ops (AbsTrainOps): The local ops. 
+ """ + raise NotImplementedError + + def get_ops(self, name: str) -> Union[RemoteOps, AbsTrainOps]: + """Create an `AbsTrainOps` instance with a given name. If a proxy address has been registered to the trainer, + this returns a `RemoteOps` instance in which all methods annotated as "remote" are turned into a remote method + call. Otherwise, a regular `AbsTrainOps` is returned. + + Args: + name (str): Ops name. + + Returns: + ops (Union[RemoteOps, AbsTrainOps]): The ops. + """ + ops = self.get_local_ops(name) + return RemoteOps(ops, self._proxy_address, logger=self._logger) if self._proxy_address else ops + + @abstractmethod + def get_policy_state(self) -> Dict[str, object]: + raise NotImplementedError + + @abstractmethod + async def exit(self) -> None: + raise NotImplementedError diff --git a/maro/rl/training/training_manager.py b/maro/rl/training/training_manager.py new file mode 100644 index 000000000..85b2c2c7d --- /dev/null +++ b/maro/rl/training/training_manager.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import asyncio +import collections +import os +import typing +from itertools import chain +from typing import Any, Dict, Iterable, List, Tuple + +from maro.rl.rollout import ExpElement +from maro.rl.training import SingleAgentTrainer +from maro.utils import LoggerV2 +from maro.utils.exception.rl_toolkit_exception import MissingTrainer + +from .trainer import AbsTrainer, MultiAgentTrainer + +if typing.TYPE_CHECKING: + from maro.rl.rl_component.rl_component_bundle import RLComponentBundle + + +class TrainingManager(object): + """ + Training manager. Manage and schedule all trainers to train policies. + + Args: + rl_component_bundle (RLComponentBundle): The RL component bundle of the job. + explicit_assign_device (bool): Whether to assign policy to its device in the training manager. + proxy_address (Tuple[str, int], default=None): Address of the training proxy. If it is not None, + it is registered to all trainers, which in turn create `RemoteOps` for distributed training. + logger (LoggerV2, default=None): A logger for logging key events. 
+ """ + + def __init__( + self, + rl_component_bundle: RLComponentBundle, + explicit_assign_device: bool, + proxy_address: Tuple[str, int] = None, + logger: LoggerV2 = None, + ) -> None: + super(TrainingManager, self).__init__() + + self._trainer_dict: Dict[str, AbsTrainer] = {} + self._proxy_address = proxy_address + for trainer_name, func in rl_component_bundle.trainer_creator.items(): + trainer = func() + if self._proxy_address: + trainer.set_proxy_address(self._proxy_address) + trainer.register_agent2policy( + rl_component_bundle.trainable_agent2policy, + rl_component_bundle.policy_trainer_mapping, + ) + trainer.register_policy_creator( + rl_component_bundle.trainable_policy_creator, + rl_component_bundle.policy_trainer_mapping, + ) + trainer.register_logger(logger) + trainer.build() # `build()` must be called after `register_policy_creator()` + self._trainer_dict[trainer_name] = trainer + + # User-defined allocation of compute devices, i.e., GPU's to the trainer ops + if explicit_assign_device: + for policy_name, device_name in rl_component_bundle.device_mapping.items(): + if policy_name not in rl_component_bundle.policy_trainer_mapping: # No need to assign device + continue + + trainer = self._trainer_dict[rl_component_bundle.policy_trainer_mapping[policy_name]] + + if isinstance(trainer, SingleAgentTrainer): + ops = trainer.ops + else: + assert isinstance(trainer, MultiAgentTrainer) + ops = trainer.ops_dict[policy_name] + ops.to_device(device_name) + + self._agent2trainer: Dict[Any, str] = {} + for agent_name, policy_name in rl_component_bundle.trainable_agent2policy.items(): + trainer_name = rl_component_bundle.policy_trainer_mapping[policy_name] + if trainer_name not in self._trainer_dict: + raise MissingTrainer(f"trainer {trainer_name} does not exist") + self._agent2trainer[agent_name] = trainer_name + + def train_step(self) -> None: + if self._proxy_address: + async def train_step() -> Iterable: + return await asyncio.gather( + *[trainer_.train_step_as_task() for trainer_ in self._trainer_dict.values()] + ) + + asyncio.run(train_step()) + else: + for trainer in self._trainer_dict.values(): + trainer.train_step() + + def get_policy_state(self) -> Dict[str, Dict[str, object]]: + """Get policies' states. + + Returns: + A double-deck dict with format: {trainer_name: {policy_name: policy_state}} + """ + return dict(chain(*[trainer.get_policy_state().items() for trainer in self._trainer_dict.values()])) + + def record_experiences(self, experiences: List[List[ExpElement]]) -> None: + """Record experiences collected from external modules (for example, EnvSampler). + + Args: + experiences (List[ExpElement]): List of experiences. Each ExpElement stores the complete information for a + tick. Please refers to the definition of ExpElement for detailed explanation of ExpElement. + """ + for env_idx, env_experience in enumerate(experiences): + trainer_exp_pool = collections.defaultdict(list) + for exp_element in env_experience: # Dispatch experiences to trainers tick by tick. 
+ exp_dict = exp_element.split_contents_by_trainer(self._agent2trainer) + for trainer_name, exp_elem in exp_dict.items(): + trainer_exp_pool[trainer_name].append(exp_elem) + + for trainer_name, exp_elems in trainer_exp_pool.items(): + trainer = self._trainer_dict[trainer_name] + trainer.record_multiple(env_idx, exp_elems) + + def load(self, path: str) -> List[str]: + loaded = [] + for trainer_name, trainer in self._trainer_dict.items(): + trainer.load(path) + loaded.append(trainer_name) + return loaded + + def save(self, path: str) -> None: + os.makedirs(path, exist_ok=True) + for trainer_name, trainer in self._trainer_dict.items(): + trainer.save(path) + + def exit(self) -> None: + if self._proxy_address: + async def exit_all() -> Iterable: + return await asyncio.gather(*[trainer.exit() for trainer in self._trainer_dict.values()]) + + asyncio.run(exit_all()) diff --git a/maro/rl/training/trajectory.py b/maro/rl/training/trajectory.py deleted file mode 100644 index cefac38e8..000000000 --- a/maro/rl/training/trajectory.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from collections import defaultdict - - -class Trajectory(object): - def __init__(self, env): - self.env = env - self.trajectory = defaultdict(list) - - def get_state(self, event) -> dict: - pass - - def get_action(self, action_by_agent, event) -> dict: - pass - - def get_reward(self) -> float: - pass - - def on_env_feedback(self, event, state_by_agent, action_by_agent, reward): - pass - - def on_finish(self): - pass - - def reset(self): - self.trajectory = defaultdict(list) diff --git a/maro/rl/training/worker.py b/maro/rl/training/worker.py new file mode 100644 index 000000000..f24d69cbf --- /dev/null +++ b/maro/rl/training/worker.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import typing +from typing import Dict + +from maro.rl.distributed import AbsWorker +from maro.rl.training import SingleAgentTrainer +from maro.rl.utils.common import bytes_to_pyobj, bytes_to_string, pyobj_to_bytes +from maro.utils import LoggerV2 + +from .train_ops import AbsTrainOps +from .trainer import AbsTrainer, MultiAgentTrainer + +if typing.TYPE_CHECKING: + from maro.rl.rl_component.rl_component_bundle import RLComponentBundle + + +class TrainOpsWorker(AbsWorker): + """Worker that executes methods defined in a subclass of ``AbsTrainOps`` and annotated as "remote" on demand. + + Args: + idx (int): Integer identifier for the worker. It is used to generate an internal ID, "worker.{idx}", + so that the proxy can keep track of its connection status. + rl_component_bundle (RLComponentBundle): The RL component bundle of the job. + producer_host (str): IP address of the proxy host to connect to. + producer_port (int, default=10001): Port of the proxy host to connect to. + """ + + def __init__( + self, + idx: int, + rl_component_bundle: RLComponentBundle, + producer_host: str, + producer_port: int = 10001, + logger: LoggerV2 = None, + ) -> None: + super(TrainOpsWorker, self).__init__( + idx=idx, producer_host=producer_host, producer_port=producer_port, logger=logger, + ) + + self._rl_component_bundle = rl_component_bundle + self._trainer_dict: Dict[str, AbsTrainer] = {} + + self._ops_dict: Dict[str, AbsTrainOps] = {} + + def _compute(self, msg: list) -> None: + """Execute a method defined by some train ops and annotated as "remote". 
+ + Args: + msg (list): Multi-part message containing task specifications and parameters. + """ + if msg[-1] == b"EXIT": + self._stream.send(b"EXIT_ACK") + self.stop() + else: + ops_name, req = bytes_to_string(msg[0]), bytes_to_pyobj(msg[-1]) + assert isinstance(req, dict) + + if ops_name not in self._ops_dict: + trainer_name = ops_name.split(".")[0] + if trainer_name not in self._trainer_dict: + trainer = self._rl_component_bundle.trainer_creator[trainer_name]() + trainer.register_policy_creator( + self._rl_component_bundle.trainable_policy_creator, + self._rl_component_bundle.policy_trainer_mapping, + ) + self._trainer_dict[trainer_name] = trainer + + trainer = self._trainer_dict[trainer_name] + if isinstance(trainer, SingleAgentTrainer): + self._ops_dict[ops_name] = trainer.get_local_ops() + else: + assert isinstance(trainer, MultiAgentTrainer) + self._ops_dict[ops_name] = trainer.get_local_ops(ops_name) + self._logger.info(f"Created ops {ops_name} at {self._id}") + + self._ops_dict[ops_name].set_state(req["state"]) + func = getattr(self._ops_dict[ops_name], req["func"]) + result = func(*req["args"], **req["kwargs"]) + self._stream.send_multipart([msg[0], pyobj_to_bytes(result)]) diff --git a/maro/rl/utils/__init__.py b/maro/rl/utils/__init__.py index e274024a3..df0917dcd 100644 --- a/maro/rl/utils/__init__.py +++ b/maro/rl/utils/__init__.py @@ -1,12 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .experience_collection import ExperienceCollectionUtils -from .trajectory_utils import get_k_step_returns, get_lambda_returns, get_truncated_cumulative_reward -from .value_utils import get_log_prob, get_max, get_td_errors, select_by_actions +from typing import Union + +from .objects import SHAPE_CHECK_FLAG +from .torch_utils import average_grads, get_torch_device, match_shape, ndarray_to_tensor +from .trajectory_computation import discount_cumsum +from .transition_batch import merge_transition_batches, MultiTransitionBatch, TransitionBatch + +AbsTransitionBatch = Union[TransitionBatch, MultiTransitionBatch] __all__ = [ - "ExperienceCollectionUtils", - "get_k_step_returns", "get_lambda_returns", "get_truncated_cumulative_reward", - "get_log_prob", "get_max", "get_td_errors", "select_by_actions", + "SHAPE_CHECK_FLAG", + "average_grads", "get_torch_device", "match_shape", "ndarray_to_tensor", + "discount_cumsum", + "AbsTransitionBatch", "MultiTransitionBatch", "TransitionBatch", "merge_transition_batches", ] diff --git a/maro/rl/utils/common.py b/maro/rl/utils/common.py new file mode 100644 index 000000000..e69b907b7 --- /dev/null +++ b/maro/rl/utils/common.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import pickle +import socket +from typing import List, Optional + + +def get_env(var_name: str, required: bool = True, default: object = None) -> str: + """Wrapper for os.getenv() that includes a check for mandatory environment variables. + + Args: + var_name (str): Variable name. + required (bool, default=True): Flag indicating whether the environment variable in questions is required. + If this is true and the environment variable is not present in ``os.environ``, a ``KeyError`` is raised. + default (object, default=None): Default value for the environment variable if it is missing in ``os.environ`` + and ``required`` is false. Ignored if ``required`` is True. + + Returns: + The environment variable. 
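+
+    Example:
+        A hypothetical usage; the variable name is illustrative only::
+
+            num_workers = int_or_none(get_env("NUM_WORKERS", required=False))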
+ """ + if var_name not in os.environ: + if required: + raise KeyError(f"Missing environment variable: {var_name}") + return default + + return os.getenv(var_name) + + +def int_or_none(val: Optional[str]) -> Optional[int]: + return int(val) if val is not None else None + + +def float_or_none(val: Optional[str]) -> Optional[float]: + return float(val) if val is not None else None + + +def list_or_none(vals_str: Optional[str]) -> List[int]: + return [int(val) for val in vals_str.split()] if vals_str is not None else [] + + +# serialization and deserialization for messaging +DEFAULT_MSG_ENCODING = "utf-8" + + +def string_to_bytes(s: str) -> bytes: + return s.encode(DEFAULT_MSG_ENCODING) + + +def bytes_to_string(bytes_: bytes) -> str: + return bytes_.decode(DEFAULT_MSG_ENCODING) + + +def pyobj_to_bytes(pyobj) -> bytes: + return pickle.dumps(pyobj) + + +def bytes_to_pyobj(bytes_: bytes) -> object: + return pickle.loads(bytes_) + + +def get_own_ip_address() -> str: + """https://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(0) + try: + # doesn't even have to be reachable + sock.connect(("10.255.255.255", 1)) + ip = sock.getsockname()[0] + except Exception: + ip = "127.0.0.1" + finally: + sock.close() + return ip + + +def get_ip_address_by_hostname(host: str) -> str: + if host in ("localhost", "127.0.0.1"): + return get_own_ip_address() + + while True: + try: + return socket.gethostbyname(host) + except Exception: + continue diff --git a/maro/rl/utils/experience_collection.py b/maro/rl/utils/experience_collection.py deleted file mode 100644 index d190a5d67..000000000 --- a/maro/rl/utils/experience_collection.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from collections import defaultdict - - -class ExperienceCollectionUtils: - @staticmethod - def concat(exp, is_single_source: bool = False, is_single_agent: bool = False) -> dict: - """Concatenate experiences from multiple sources, by agent ID. - - The experience from each source is expected to be already grouped by agent ID. The result is a single dictionary - of experiences with keys being agent IDs and values being the concatenation of experiences from all sources - for each agent ID. - - Args: - exp: Experiences from one or more sources. - is_single_source (bool): If True, experiences are from a single (actor) source. Defaults to False. - is_single_agent (bool): If True, experiences are from a single agent. Defaults to False. - - Returns: - Concatenated experiences for each agent. - """ - if is_single_source: - return exp - - merged = defaultdict(list) if is_single_agent else defaultdict(lambda: defaultdict(list)) - for ex in exp.values(): - if is_single_agent: - for k, v in ex.items(): - merged[k].extend[v] - else: - for agent_id, e in ex.items(): - for k, v in e.items(): - merged[agent_id][k].extend(v) - - return merged - - @staticmethod - def stack(exp, is_single_source: bool = False, is_single_agent: bool = False) -> dict: - """Collect each agent's trajectories from multiple sources. - - Args: - exp: Experiences from one or more sources. - is_single_source (bool): If True, experiences are from a single (actor) source. Defaults to False. - is_single_agent (bool): If True, the experiences are from a single agent. Defaults to False. - - Returns: - A list of trajectories for each agent. 
- """ - if is_single_source: - return [exp] if is_single_agent else {agent_id: [ex] for agent_id, ex in exp.items()} - - if is_single_agent: - return list(exp.values()) - - ret = defaultdict(list) - for ex in exp.values(): - for agent_id, e in ex.items(): - ret[agent_id].append(e) - - return ret diff --git a/maro/rl/utils/message_enums.py b/maro/rl/utils/message_enums.py new file mode 100644 index 000000000..3da0bd74d --- /dev/null +++ b/maro/rl/utils/message_enums.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from enum import Enum + + +class MsgTag(Enum): + SAMPLE = "sample" + TEST = "test" + SAMPLE_DONE = "eval_done" + TEST_DONE = "collect_done" + INIT_POLICIES = "init_policies" + INIT_POLICIES_DONE = "init_policies_done" + POLICY_STATE = "policy_state" + CHOOSE_ACTION = "choose_action" + ACTION = "action" + GET_INITIAL_POLICY_STATE = "get_initial_policy_state" + LEARN = "learn" + LEARN_DONE = "learn_finished" + COMPUTE_GRAD = "compute_grad" + COMPUTE_GRAD_DONE = "compute_grad_done" + ABORT_ROLLOUT = "abort_rollout" + DONE = "done" + EXIT = "exit" + REQUEST_WORKER = "request_worker" + RELEASE_WORKER = "release_worker" + ASSIGN_WORKER = "assign_worker" + + +class MsgKey(Enum): + ACTION = "action" + AGENT_ID = "agent_id" + EPISODE = "episode" + SEGMENT = "segment" + NUM_STEPS = "num_steps" + STEP = "step" + POLICY_IDS = "policy_ids" + ROLLOUT_INFO = "rollout_info" + INTO = "info" + GRAD_TASK = "grad_task" + GRAD_SCOPE = "grad_scope" + LOSS_INFO = "loss_info" + STATE = "state" + TENSOR = "tensor" + POLICY_STATE = "policy_state" + EXPLORATION_STEP = "exploration_step" + VERSION = "version" + STEP_RANGE = "step_range" + END_OF_EPISODE = "end_of_episode" + WORKER_ID = "worker_id" + WORKER_ID_LIST = "worker_id_list" diff --git a/examples/cim/__init__.py b/maro/rl/utils/objects.py similarity index 61% rename from examples/cim/__init__.py rename to maro/rl/utils/objects.py index b14b47650..0a73c0f66 100644 --- a/examples/cim/__init__.py +++ b/maro/rl/utils/objects.py @@ -1,2 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +SHAPE_CHECK_FLAG = True +FILE_SUFFIX = "ckpt" diff --git a/maro/rl/utils/torch_utils.py b/maro/rl/utils/torch_utils.py new file mode 100644 index 000000000..914efc6e6 --- /dev/null +++ b/maro/rl/utils/torch_utils.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import List, Union + +import numpy as np +import torch + +from .objects import SHAPE_CHECK_FLAG + + +def match_shape(tensor: Union[torch.Tensor, np.ndarray], shape: tuple) -> bool: + """Check if a torch.Tensor / np.ndarray could match the expected shape. + + Args: + tensor (Union[torch.Tensor, np.ndarray]): Tensor. + shape (tuple): The expected shape tuple. If an element in this tuple is None, it means this dimension + could match any value (usually used for the `batch_size` dimension). + + Returns: + Whether the tensor could match the expected shape. + """ + if not SHAPE_CHECK_FLAG: + return True + else: + if len(tensor.shape) != len(shape): + return False + for val, expected in zip(tensor.shape, shape): + if expected is not None and expected != val: + return False + return True + + +def ndarray_to_tensor(array: np.ndarray, device: torch.device = None) -> torch.Tensor: + """ + Convert a np.ndarray to a torch.Tensor. + + Args: + array (np.ndarray): The input ndarray. 
+ device (torch.device): The device to assign this tensor. + + Returns: + A tensor with same shape and values. + """ + return torch.from_numpy(array).to(device) + + +def average_grads(grad_list: List[dict]) -> dict: + """Obtain the average of a list of gradients. + """ + if len(grad_list) == 1: + return grad_list[0] + return { + param_name: torch.mean(torch.stack([grad[param_name] for grad in grad_list]), dim=0) + for param_name in grad_list[0] + } + + +def get_torch_device(device: str = None): + return torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu")) diff --git a/maro/rl/utils/training.py b/maro/rl/utils/training.py new file mode 100644 index 000000000..ef208bec4 --- /dev/null +++ b/maro/rl/utils/training.py @@ -0,0 +1,6 @@ +import os + + +def get_latest_ep(path: str) -> int: + ep_list = [int(ep) for ep in os.listdir(path)] + return max(ep_list) diff --git a/maro/rl/utils/trajectory_computation.py b/maro/rl/utils/trajectory_computation.py new file mode 100644 index 000000000..dbfba4862 --- /dev/null +++ b/maro/rl/utils/trajectory_computation.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Union + +import numpy as np +import scipy.signal + + +def discount_cumsum(x: Union[np.ndarray, list], discount: float) -> np.ndarray: + """ + Magic from rllab for computing discounted cumulative sums of vectors. + + Original code from: + https://github.com/rll/rllab/blob/master/rllab/misc/special.py). + + For details about the scipy function, see: + https://docs.scipy.org/doc/scipy/reference/tutorial/signal.html#difference-equation-filtering + + input: + vector x, + [x0, x1, x2] + + output: + [x0 + discount * x1 + discount^2 * x2, x1 + discount * x2, x2] + """ + return np.array(scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1], dtype=np.float32) diff --git a/maro/rl/utils/trajectory_utils.py b/maro/rl/utils/trajectory_utils.py deleted file mode 100644 index 3fe348adb..000000000 --- a/maro/rl/utils/trajectory_utils.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from functools import reduce -from typing import Union - -import numpy as np -import torch -import torch.nn.functional as F - - -def get_truncated_cumulative_reward( - rewards: Union[list, np.ndarray, torch.Tensor], - discount: float, - k: int = -1 -): - """Compute K-step cumulative rewards from a reward sequence. - Args: - rewards (Union[list, np.ndarray, torch.Tensor]): Reward sequence from a trajectory. - discount (float): Reward discount as in standard RL. - k (int): Number of steps in computing cumulative rewards. If it is -1, returns are computed using the - largest possible number of steps. Defaults to -1. - - Returns: - An ndarray or torch.Tensor instance containing the k-step cumulative rewards for each time step. - """ - if k < 0: - k = len(rewards) - 1 - pad = np.pad if isinstance(rewards, list) or isinstance(rewards, np.ndarray) else F.pad - return reduce( - lambda x, y: x * discount + y, - [pad(rewards[i:], (0, i)) for i in range(min(k, len(rewards)) - 1, -1, -1)] - ) - - -def get_k_step_returns( - rewards: Union[list, np.ndarray, torch.Tensor], - values: Union[list, np.ndarray, torch.Tensor], - discount: float, - k: int = -1 -): - """Compute K-step returns given reward and value sequences. - Args: - rewards (Union[list, np.ndarray, torch.Tensor]): Reward sequence from a trajectory. 
- values (Union[list, np.ndarray, torch.Tensor]): Sequence of values for the traversed states in a trajectory. - discount (float): Reward discount as in standard RL. - k (int): Number of steps in computing returns. If it is -1, returns are computed using the largest possible - number of steps. Defaults to -1. - - Returns: - An ndarray or torch.Tensor instance containing the k-step returns for each time step. - """ - assert len(rewards) == len(values), "rewards and values should have the same length" - assert len(values.shape) == 1, "values should be a one-dimensional array" - rewards[-1] = values[-1] - if k < 0: - k = len(rewards) - 1 - pad = np.pad if isinstance(rewards, list) or isinstance(rewards, np.ndarray) else F.pad - return reduce( - lambda x, y: x * discount + y, - [pad(rewards[i:], (0, i)) for i in range(min(k, len(rewards)) - 1, -1, -1)], - pad(values[k:], (0, k)) - ) - - -def get_lambda_returns( - rewards: Union[list, np.ndarray, torch.Tensor], - values: Union[list, np.ndarray, torch.Tensor], - discount: float, - lam: float, - k: int = -1 -): - """Compute lambda returns given reward and value sequences and a k. - Args: - rewards (Union[list, np.ndarray, torch.Tensor]): Reward sequence from a trajectory. - values (Union[list, np.ndarray, torch.Tensor]): Sequence of values for the traversed states in a trajectory. - discount (float): Reward discount as in standard RL. - lam (float): Lambda coefficient involved in computing lambda returns. - k (int): Number of steps where the lambda return series is truncated. If it is -1, no truncating is done and - the lambda return is carried out to the end of the sequence. Defaults to -1. - - Returns: - An ndarray or torch.Tensor instance containing the lambda returns for each time step. - """ - if k < 0: - k = len(rewards) - 1 - - # If lambda is zero, lambda return reduces to one-step return - if lam == .0: - return get_k_step_returns(rewards, values, discount, k=1) - - # If lambda is one, lambda return reduces to k-step return - if lam == 1.0: - return get_k_step_returns(rewards, values, discount, k=k) - - k = min(k, len(rewards) - 1) - pre_truncate = reduce( - lambda x, y: x * lam + y, - [get_k_step_returns(rewards, values, discount, k=k) for k in range(k - 1, 0, -1)] - ) - - post_truncate = get_k_step_returns(rewards, values, discount, k=k) * lam**(k - 1) - return (1 - lam) * pre_truncate + post_truncate diff --git a/maro/rl/utils/transition_batch.py b/maro/rl/utils/transition_batch.py new file mode 100644 index 000000000..26432d452 --- /dev/null +++ b/maro/rl/utils/transition_batch.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional + +import numpy as np + +from . 
import discount_cumsum +from .objects import SHAPE_CHECK_FLAG + + +@dataclass +class TransitionBatch: + states: np.ndarray # 2D + actions: np.ndarray # 2D + rewards: np.ndarray # 1D + next_states: np.ndarray # 2D + terminals: np.ndarray # 1D + returns: np.ndarray = None # 1D + advantages: np.ndarray = None # 1D + old_logps: np.ndarray = None # 1D + + @property + def size(self) -> int: + return self.states.shape[0] + + def __post_init__(self) -> None: + if SHAPE_CHECK_FLAG: + assert len(self.states.shape) == 2 and self.states.shape[0] > 0 + assert len(self.actions.shape) == 2 and self.actions.shape[0] == self.states.shape[0] + assert len(self.rewards.shape) == 1 and self.rewards.shape[0] == self.states.shape[0] + assert self.next_states.shape == self.states.shape + assert len(self.terminals.shape) == 1 and self.terminals.shape[0] == self.states.shape[0] + + def make_kth_sub_batch(self, i: int, k: int) -> TransitionBatch: + return TransitionBatch( + states=self.states[i::k], + actions=self.actions[i::k], + rewards=self.rewards[i::k], + next_states=self.next_states[i::k], + terminals=self.terminals[i::k], + returns=self.returns[i::k] if self.returns is not None else None, + advantages=self.advantages[i::k] if self.advantages is not None else None, + old_logps=self.old_logps[i::k] if self.old_logps is not None else None, + ) + + def split(self, k: int) -> List[TransitionBatch]: + return [self.make_kth_sub_batch(i, k) for i in range(k)] + + +@dataclass +class MultiTransitionBatch: + states: np.ndarray # 2D + actions: List[np.ndarray] # List of 2D + rewards: List[np.ndarray] # List of 1D + next_states: np.ndarray # 2D + agent_states: List[np.ndarray] # List of 2D + next_agent_states: List[np.ndarray] # List of 2D + terminals: np.ndarray # 1D + + returns: Optional[List[np.ndarray]] = None # List of 1D + advantages: Optional[List[np.ndarray]] = None # List of 1D + + @property + def size(self) -> int: + return self.states.shape[0] + + def __post_init__(self) -> None: + if SHAPE_CHECK_FLAG: + assert len(self.states.shape) == 2 and self.states.shape[0] > 0 + + assert len(self.actions) == len(self.rewards) + assert len(self.agent_states) == len(self.actions) + for i in range(len(self.actions)): + assert len(self.actions[i].shape) == 2 and self.actions[i].shape[0] == self.states.shape[0] + assert len(self.rewards[i].shape) == 1 and self.rewards[i].shape[0] == self.states.shape[0] + assert len(self.agent_states[i].shape) == 2 + assert self.agent_states[i].shape[0] == self.states.shape[0] + + assert len(self.terminals.shape) == 1 and self.terminals.shape[0] == self.states.shape[0] + assert self.next_states.shape == self.states.shape + + assert len(self.next_agent_states) == len(self.agent_states) + for i in range(len(self.next_agent_states)): + assert self.agent_states[i].shape == self.next_agent_states[i].shape + + def calc_returns(self, discount_factor: float) -> None: + self.returns = [discount_cumsum(reward, discount_factor) for reward in self.rewards] + + def make_kth_sub_batch(self, i: int, k: int) -> MultiTransitionBatch: + states = self.states[i::k] + actions = [action[i::k] for action in self.actions] + rewards = [reward[i::k] for reward in self.rewards] + next_states = self.next_states[i::k] + agent_states = [state[i::k] for state in self.agent_states] + next_agent_states = [state[i::k] for state in self.next_agent_states] + terminals = self.terminals[i::k] + returns = None if self.returns is None else [r[i::k] for r in self.returns] + advantages = None if self.advantages is None else 
[advantage[i::k] for advantage in self.advantages] + return MultiTransitionBatch( + states, actions, rewards, next_states, agent_states, + next_agent_states, terminals, returns, advantages, + ) + + def split(self, k: int) -> List[MultiTransitionBatch]: + return [self.make_kth_sub_batch(i, k) for i in range(k)] + + +def merge_transition_batches(batch_list: List[TransitionBatch]) -> TransitionBatch: + return TransitionBatch( + states=np.concatenate([batch.states for batch in batch_list], axis=0), + actions=np.concatenate([batch.actions for batch in batch_list], axis=0), + rewards=np.concatenate([batch.rewards for batch in batch_list], axis=0), + next_states=np.concatenate([batch.next_states for batch in batch_list], axis=0), + terminals=np.concatenate([batch.terminals for batch in batch_list]), + returns=np.concatenate([batch.returns for batch in batch_list]), + advantages=np.concatenate([batch.advantages for batch in batch_list]), + old_logps=None if batch_list[0].old_logps is None else np.concatenate( + [batch.old_logps for batch in batch_list] + ), + ) diff --git a/maro/rl/utils/value_utils.py b/maro/rl/utils/value_utils.py deleted file mode 100644 index 99de02b6e..000000000 --- a/maro/rl/utils/value_utils.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from typing import Callable - -import torch - - -def select_by_actions(q_values: torch.Tensor, actions: torch.Tensor): - if len(actions.shape) == 1: - actions = actions.unsqueeze(1) # (N, 1) - return q_values.gather(1, actions).squeeze(1) - - -def get_max(q_values: torch.Tensor, expand_action_dim: bool = True): - """ - Given Q-values for a batch of states and all actions, return the maximum Q-value and - the corresponding action index for each state. - """ - greedy_q, actions = q_values.max(dim=1) - if expand_action_dim: - actions = actions.unsqueeze(1) - return greedy_q, actions - - -def get_td_errors( - q_values: torch.Tensor, next_q_values: torch.Tensor, rewards: torch.Tensor, gamma: float, - loss_func: Callable -): - target_q_values = (rewards + gamma * next_q_values).detach() # (N,) - return loss_func(q_values, target_q_values) - - -def get_log_prob(action_probs: torch.Tensor, actions: torch.Tensor): - return torch.log(action_probs.gather(1, actions.unsqueeze(1)).squeeze()) # (N,) diff --git a/maro/cli/process/utils/__init__.py b/maro/rl/workflows/__init__.py similarity index 100% rename from maro/cli/process/utils/__init__.py rename to maro/rl/workflows/__init__.py diff --git a/examples/cim/dqn/__init__.py b/maro/rl/workflows/config/__init__.py similarity index 51% rename from examples/cim/dqn/__init__.py rename to maro/rl/workflows/config/__init__.py index b14b47650..59348d2b7 100644 --- a/examples/cim/dqn/__init__.py +++ b/maro/rl/workflows/config/__init__.py @@ -1,2 +1,8 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .parser import ConfigParser + +__all__ = [ + "ConfigParser", +] diff --git a/maro/rl/workflows/config/parser.py b/maro/rl/workflows/config/parser.py new file mode 100644 index 000000000..80e7b6971 --- /dev/null +++ b/maro/rl/workflows/config/parser.py @@ -0,0 +1,396 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import ipaddress +import os +from typing import Dict, Tuple, Union + +import yaml + +from maro.utils.logger import LEVEL_MAP + + +class ConfigParser: + """Configuration parser for running RL workflows. 
+ + Args: + config (Union[str, dict]): A dictionary configuration or a path to a Yaml file that contains + the configuration. If it is a path, the parser will attempt to read it into a dictionary + in memory. + """ + + def __init__(self, config: Union[str, dict]) -> None: + assert isinstance(config, (dict, str)) + if isinstance(config, str): + with open(config, "r") as fp: + self._config = yaml.safe_load(fp) + else: + self._config = config + + self._validation_err_pfx = f"Invalid configuration: {self._config}" + self._validate() + + @property + def config(self) -> dict: + return self._config + + def _validate(self) -> None: + if "job" not in self._config: + raise KeyError(f"{self._validation_err_pfx}: missing field 'job'") + if "scenario_path" not in self._config: + raise KeyError(f"{self._validation_err_pfx}: missing field 'scenario_path'") + if "log_path" not in self._config: + raise KeyError(f"{self._validation_err_pfx}: missing field 'log_path'") + + self._validate_main_section() + self._validate_rollout_section() + self._validate_training_section() + + def _validate_main_section(self) -> None: + if "main" not in self._config: + raise KeyError(f"{self._validation_err_pfx}: missing field 'main'") + + if "num_episodes" not in self._config["main"]: + raise KeyError(f"{self._validation_err_pfx}: missing field 'num_episodes' under section 'main'") + + num_episodes = self._config["main"]["num_episodes"] + if not isinstance(num_episodes, int) or num_episodes < 1: + raise ValueError(f"{self._validation_err_pfx}: 'main.num_episodes' must be a positive int") + + num_steps = self._config["main"].get("num_steps", None) + if num_steps is not None: + if not isinstance(num_steps, int) or num_steps <= 0: + raise ValueError(f"{self._validation_err_pfx}: 'main.num_steps' must be a positive int") + + eval_schedule = self._config["main"].get("eval_schedule", None) + if eval_schedule is not None: + if ( + not isinstance(eval_schedule, (int, list)) or + isinstance(eval_schedule, int) and eval_schedule < 1 or + isinstance(eval_schedule, list) and any(not isinstance(val, int) or val < 1 for val in eval_schedule) + ): + raise ValueError( + f"{self._validation_err_pfx}: 'main.eval_schedule' must be a single positive int or a list of " + f"positive ints" + ) + + if "logging" in self._config["main"]: + self._validate_logging_section("main", self._config["main"]["logging"]) + + def _validate_rollout_section(self) -> None: + if "rollout" not in self._config or not isinstance(self._config["rollout"], dict): + raise KeyError(f"{self._validation_err_pfx}: missing section 'rollout'") + + # validate parallel rollout config + if "parallelism" in self._config["rollout"]: + conf = self._config["rollout"]["parallelism"] + if "sampling" not in conf: + raise KeyError( + f"{self._validation_err_pfx}: missing field 'sampling' under section 'rollout.parallelism'" + ) + + train_prl = conf["sampling"] + eval_prl = 1 if "eval" not in conf or conf["eval"] is None else conf["eval"] + if not isinstance(train_prl, int) or train_prl <= 0: + raise TypeError(f"{self._validation_err_pfx}: 'rollout.parallelism.sampling' must be a positive int") + if not isinstance(eval_prl, int) or eval_prl <= 0: + raise TypeError(f"{self._validation_err_pfx}: 'rollout.parallelism.eval' must be a positive int") + if max(train_prl, eval_prl) > 1: + if "controller" not in conf: + raise KeyError( + f"{self._validation_err_pfx}: missing field 'controller' under section 'rollout.parallelism'" + ) + self._validate_rollout_controller_section(conf["controller"]) + 
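For orientation, a minimal configuration that satisfies the rollout-parallelism checks above, as well as the rest of the validation in this class, could look like the sketch below. Every value is a placeholder, and the only behaviour assumed is what is visible in this file (the constructor accepts a plain dict and runs the same validation).

    # Minimal illustrative config; all values are placeholders.
    from maro.rl.workflows.config import ConfigParser

    minimal_config = {
        "job": "my_job",
        "scenario_path": "/path/to/scenario",
        "log_path": "/path/to/logs/run.log",
        "main": {"num_episodes": 10},
        "rollout": {
            "parallelism": {
                "sampling": 2,
                "controller": {"host": "127.0.0.1", "port": 20000},
            },
        },
        "training": {"mode": "simple"},
    }
    parser = ConfigParser(minimal_config)  # _validate() passes; malformed fields raise KeyError/TypeError/ValueError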
+ # validate optional fields: min_env_samples, grace_factor + min_env_samples = conf.get("min_env_samples", None) + if min_env_samples is not None: + if not isinstance(min_env_samples, int) or min_env_samples > train_prl: + raise ValueError( + f"{self._validation_err_pfx}: 'rollout.parallelism.min_env_samples' must be an integer " + f"that does not exceed the value of 'rollout.parallelism.sampling': {train_prl}" + ) + + grace_factor = conf.get("grace_factor", None) + if grace_factor is not None and not isinstance(grace_factor, (int, float)): + raise ValueError( + f"{self._validation_err_pfx}: 'rollout.parallelism.grace_factor' must be an int or float" + ) + + if "logging" in self._config["rollout"]: + self._validate_logging_section("rollout", self._config["rollout"]["logging"]) + + def _validate_rollout_controller_section(self, conf: dict) -> None: + if "host" not in conf: + raise KeyError( + f"{self._validation_err_pfx}: missing field 'host' under section 'rollout.parallelism.controller'" + ) + if not isinstance(conf["host"], str): + raise TypeError(f"{self._validation_err_pfx}: 'rollout.parallelism.controller.host' must be a string") + + # Check that the host string is a valid IP address + try: + ipaddress.ip_address(conf["host"]) + except ValueError: + raise ValueError( + f"{self._validation_err_pfx}: 'rollout.parallelism.controller.host' is not a valid IP address" + ) + + if "port" not in conf: + raise KeyError( + f"{self._validation_err_pfx}: missing field 'port' under section 'rollout.parallelism.controller'" + ) + if not isinstance(conf["port"], int): + raise TypeError(f"{self._validation_err_pfx}: 'rollout.parallelism.controller.port' must be an int") + + def _validate_training_section(self) -> None: + if "training" not in self._config or not isinstance(self._config["training"], dict): + raise KeyError(f"{self._validation_err_pfx}: missing field 'training'") + if "mode" not in self._config["training"]: + raise KeyError(f"{self._validation_err_pfx}: missing field 'mode' under section 'training'") + if self._config["training"]["mode"] not in {"simple", "parallel"}: + raise ValueError( + f"'mode' value under section 'training' must be 'simple' or 'parallel', got {self._config['training']['mode']}" + ) + + if self._config["training"]["mode"] == "parallel": + if "num_workers" not in self._config["training"]: + raise KeyError(f"{self._validation_err_pfx}: missing field 'num_workers' under section 'training'") + if "proxy" not in self._config["training"]: + raise KeyError(f"{self._validation_err_pfx}: missing field 'proxy' under section 'training'") + self._validate_train_proxy_section(self._config["training"]["proxy"]) + if "logging" in self._config["training"]: + self._validate_logging_section("training", self._config["training"]["logging"]) + + load_path = self._config["training"].get("load_path", None) + if load_path is not None and not isinstance(load_path, str): + raise TypeError(f"{self._validation_err_pfx}: 'training.load_path' must be a string") + load_episode = self._config["training"].get("load_episode", None) + if load_episode is not None and not isinstance(load_episode, int): + raise TypeError(f"{self._validation_err_pfx}: 'training.load_episode' must be an integer") + + if "checkpointing" in self._config["training"]: + self._validate_checkpointing_section(self._config["training"]["checkpointing"]) + + def _validate_train_proxy_section(self, proxy_section: dict) -> None: + if "host" not in proxy_section: + raise KeyError(f"{self._validation_err_pfx}: missing field 'host' under section
'proxy'") + if not isinstance(proxy_section["host"], str): + raise TypeError(f"{self._validation_err_pfx}: 'training.proxy.host' must be a string") + # Check that the host string is a valid IP address + try: + ipaddress.ip_address(proxy_section["host"]) + except ValueError: + raise ValueError(f"{self._validation_err_pfx}: 'training.proxy.host' is not a valid IP address") + + if "frontend" not in proxy_section: + raise KeyError(f"{self._validation_err_pfx}: missing field 'frontend' under section 'proxy'") + if not isinstance(proxy_section["frontend"], int): + raise TypeError(f"{self._validation_err_pfx}: 'training.proxy.frontend' must be an int") + + if "backend" not in proxy_section: + raise KeyError(f"{self._validation_err_pfx}: missing field 'backend' under section 'proxy'") + if not isinstance(proxy_section["backend"], int): + raise TypeError(f"{self._validation_err_pfx}: 'training.proxy.backend' must be an int") + + def _validate_checkpointing_section(self, section: dict) -> None: + if "path" not in section: + raise KeyError(f"{self._validation_err_pfx}: missing field 'path' under section 'checkpointing'") + if not isinstance(section["path"], str): + raise TypeError(f"{self._validation_err_pfx}: 'training.checkpointing.path' must be a string") + + if "interval" in section: + if not isinstance(section["interval"], int): + raise TypeError( + f"{self._validation_err_pfx}: 'training.checkpointing.interval' must be an int" + ) + + def _validate_logging_section(self, component, level_dict: dict) -> None: + if any(key not in {"stdout", "file"} for key in level_dict): + raise KeyError( + f"{self._validation_err_pfx}: fields under section '{component}.logging' must be 'stdout' or 'file'" + ) + valid_log_levels = set(LEVEL_MAP.keys()) + for key, val in level_dict.items(): + if val not in valid_log_levels: + raise ValueError( + f"{self._validation_err_pfx}: '{component}.logging.{key}' must be one of {valid_log_levels}." + ) + + def get_path_mapping(self, containerize: bool = False) -> dict: + """Generate path mappings for a local or containerized environment. + + Args: + containerize (bool): If true, the paths you specify in the configuration file (which should always be local) + are mapped to paths inside the containers as follows: + local/scenario/path -> "/scenario" + local/load/path -> "/loadpoint" + local/checkpoint/path -> "/checkpoints" + local/log/path -> "/logs" + Defaults to False. + """ + log_dir = os.path.dirname(self._config["log_path"]) + path_map = { + self._config["scenario_path"]: "/scenario" if containerize else self._config["scenario_path"], + log_dir: "/logs" if containerize else log_dir + } + + load_path = self._config["training"].get("load_path", None) + if load_path is not None: + path_map[load_path] = "/loadpoint" if containerize else load_path + if "checkpointing" in self._config["training"]: + ckpt_path = self._config["training"]["checkpointing"]["path"] + path_map[ckpt_path] = "/checkpoints" if containerize else ckpt_path + + return path_map + + def get_job_spec(self, containerize: bool = False) -> Dict[str, Tuple[str, Dict[str, str]]]: + """Generate environment variables for the workflow scripts. + + A doubly-nested dictionary is returned that contains the environment variables for each distributed component. + + Args: + containerize (bool): If true, the generated environment variables are to be used in a containerized + environment. Only path-related environment variables are affected by this flag. See the docstring + for ``get_path_mappings`` for details. 
Defaults to False. + """ + path_mapping = self.get_path_mapping(containerize=containerize) + scenario_path = path_mapping[self._config["scenario_path"]] + num_episodes = self._config["main"]["num_episodes"] + main_proc = f"{self._config['job']}.main" + min_n_sample = self._config["main"].get("min_n_sample", 1) + env = { + main_proc: ( + os.path.join(self._get_workflow_path(containerize=containerize), "main.py"), + { + "JOB": self._config["job"], + "NUM_EPISODES": str(num_episodes), + "MIN_N_SAMPLE": str(min_n_sample), + "TRAIN_MODE": self._config["training"]["mode"], + "SCENARIO_PATH": scenario_path, + } + ) + } + + main_proc_env = env[main_proc][1] + if "eval_schedule" in self._config["main"]: + # If it is an int, it is treated as the number of episodes between two adjacent evaluations. For example, + # if the total number of episodes is 20 and this is 5, an evaluation schedule of [5, 10, 15, 20] + # (start from 1) will be generated for the environment variable (as a string). If it is a list, the sorted + # version of the list will be generated for the environment variable (as a string). + sch = self._config["main"]["eval_schedule"] + if isinstance(sch, int): + main_proc_env["EVAL_SCHEDULE"] = " ".join([str(sch * i) for i in range(1, num_episodes // sch + 1)]) + else: + main_proc_env["EVAL_SCHEDULE"] = " ".join([str(val) for val in sorted(sch)]) + + load_path = self._config["training"].get("load_path", None) + if load_path is not None: + main_proc_env["LOAD_PATH"] = path_mapping[load_path] + load_episode = self._config["training"].get("load_episode", None) + if load_episode is not None: + main_proc_env["LOAD_EPISODE"] = str(load_episode) + + if "checkpointing" in self._config["training"]: + conf = self._config["training"]["checkpointing"] + main_proc_env["CHECKPOINT_PATH"] = path_mapping[conf["path"]] + if "interval" in conf: + main_proc_env["CHECKPOINT_INTERVAL"] = str(conf["interval"]) + + num_steps = self._config["main"].get("num_steps", None) + if num_steps is not None: + main_proc_env["NUM_STEPS"] = str(num_steps) + + if "logging" in self._config["main"]: + main_proc_env.update({ + "LOG_LEVEL_STDOUT": self.config["main"]["logging"]["stdout"], + "LOG_LEVEL_FILE": self.config["main"]["logging"]["file"], + }) + + if "parallelism" in self._config["rollout"]: + conf = self._config["rollout"]["parallelism"] + env_sampling_parallelism = conf["sampling"] + env_eval_parallelism = 1 if "eval" not in conf or conf["eval"] is None else conf["eval"] + else: + env_sampling_parallelism = env_eval_parallelism = 1 + rollout_parallelism = max(env_sampling_parallelism, env_eval_parallelism) + if rollout_parallelism > 1: + conf = self._config["rollout"]["parallelism"] + rollout_controller_port = str(conf["controller"]["port"]) + main_proc_env["ENV_SAMPLE_PARALLELISM"] = str(env_sampling_parallelism) + main_proc_env["ENV_EVAL_PARALLELISM"] = str(env_eval_parallelism) + main_proc_env["ROLLOUT_CONTROLLER_PORT"] = rollout_controller_port + # optional settings for parallel rollout + if "min_env_samples" in conf: + main_proc_env["MIN_ENV_SAMPLES"] = str(conf["min_env_samples"]) + if "grace_factor" in conf: + main_proc_env["GRACE_FACTOR"] = str(conf["grace_factor"]) + + for i in range(rollout_parallelism): + worker_id = f"{self._config['job']}.rollout_worker-{i}" + env[worker_id] = ( + os.path.join(self._get_workflow_path(containerize=containerize), "rollout_worker.py"), + { + "ID": str(i), + "ROLLOUT_CONTROLLER_HOST":
self._get_rollout_controller_host(containerize=containerize), + "ROLLOUT_CONTROLLER_PORT": rollout_controller_port, + "SCENARIO_PATH": scenario_path, + } + ) + if "logging" in self._config["rollout"]: + env[worker_id][1].update({ + "LOG_LEVEL_STDOUT": self.config["rollout"]["logging"]["stdout"], + "LOG_LEVEL_FILE": self.config["rollout"]["logging"]["file"], + }) + + if self._config["training"]["mode"] == "parallel": + conf = self._config['training']['proxy'] + producer_host = self._get_train_proxy_host(containerize=containerize) + proxy_frontend_port = str(conf["frontend"]) + proxy_backend_port = str(conf["backend"]) + num_workers = self._config["training"]["num_workers"] + env[main_proc][1].update({ + "TRAIN_PROXY_HOST": producer_host, "TRAIN_PROXY_FRONTEND_PORT": proxy_frontend_port, + }) + env[f"{self._config['job']}.train_proxy"] = ( + os.path.join(self._get_workflow_path(containerize=containerize), "train_proxy.py"), + {"TRAIN_PROXY_FRONTEND_PORT": proxy_frontend_port, "TRAIN_PROXY_BACKEND_PORT": proxy_backend_port} + ) + for i in range(num_workers): + worker_id = f"{self._config['job']}.train_worker-{i}" + env[worker_id] = ( + os.path.join(self._get_workflow_path(containerize=containerize), "train_worker.py"), + { + "ID": str(i), + "TRAIN_PROXY_HOST": producer_host, + "TRAIN_PROXY_BACKEND_PORT": proxy_backend_port, + "SCENARIO_PATH": scenario_path, + } + ) + if "logging" in self._config["training"]: + env[worker_id][1].update({ + "LOG_LEVEL_STDOUT": self.config["training"]["logging"]["stdout"], + "LOG_LEVEL_FILE": self.config["training"]["logging"]["file"], + }) + + # All components write logs to the same file + log_dir, log_file = os.path.split(self._config["log_path"]) + for _, vars in env.values(): + vars["LOG_PATH"] = os.path.join(path_mapping[log_dir], log_file) + + return env + + def _get_workflow_path(self, containerize: bool = False) -> str: + if containerize: + return "/maro/maro/rl/workflows" + else: + return os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + + def _get_rollout_controller_host(self, containerize: bool = False) -> str: + if containerize: + return f"{self._config['job']}.main" + else: + return self._config["rollout"]["parallelism"]["controller"]["host"] + + def _get_train_proxy_host(self, containerize: bool = False) -> str: + return f"{self._config['job']}.train_proxy" if containerize else self._config["training"]["proxy"]["host"] diff --git a/maro/rl/workflows/config/template.yml b/maro/rl/workflows/config/template.yml new file mode 100644 index 000000000..3464e9edc --- /dev/null +++ b/maro/rl/workflows/config/template.yml @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# This is a configuration template for running reinforcement learning workflows with MARO's CLI tools. The workflows +# are scenario agnostic, meaning that this template can be applied to any scenario as long as the necessary components +# are provided (see examples/rl/README.md for details about these components). Your scenario should be placed in a +# folder and its path should be specified in the "scenario_path" field. Note that all fields with a "null" value are +# optional and will be converted to None by the parser unless a non-null value is specified. Note that commenting them +# out or leaving them blank are equivalent to using "null". 
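To make the template concrete, the sketch below shows how a configuration written against it is consumed programmatically. The file path is a placeholder; only methods defined in parser.py above are used.

    # Illustrative usage of the parser; "/path/to/your/config.yml" is a placeholder.
    from maro.rl.workflows.config import ConfigParser

    parser = ConfigParser("/path/to/your/config.yml")  # a dict with the same structure also works
    path_mapping = parser.get_path_mapping(containerize=False)

    # get_job_spec() returns {component name: (script to run, env vars for that script)},
    # with component names such as "<job>.main", "<job>.rollout_worker-0" and "<job>.train_worker-0".
    for component, (script, env_vars) in parser.get_job_spec(containerize=False).items():
        print(component, script, sorted(env_vars))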
+ + +job: your_job_name +# Path to a directory that defines a business scenario and contains the necessary components to execute reinforcement +# learning workflows in single-threaded, multi-process and distributed modes. +scenario_path: "/path/to/your/scenario" +log_path: "/path/to/your/log/folder" # All logs are written to a single file for ease of viewing. +main: + num_episodes: 100 # Number of episodes to run. Each episode is one cycle of roll-out and training. + # Number of environment steps to collect environment samples over. If null, samples are collected until the + # environments reach the terminal state, i.e., for a full episode. Otherwise, samples are collected until the + # specified number of steps or the terminal state is reached, whichever comes first. + num_steps: null + # This can be an integer or a list of integers. An integer indicates the interval at which policies are evaluated. + # A list indicates the episodes at the end of which policies are to be evaluated. Note that episode indexes are + # 1-based. + eval_schedule: 10 + # Minimum number of samples to start training in one epoch. The workflow will re-run experience collection + # until we have at least `min_n_sample` of experiences. + min_n_sample: 1 + logging: # log levels for the main loop + stdout: INFO # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS + file: DEBUG # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS +rollout: + # Optional section to specify roll-out parallelism settings. If absent, a single environment instance will be created + # locally for training and evaluation. + parallelism: + sampling: 10 # Number of parallel roll-outs to collecting training data from. + # Number of parallel roll-outs to evaluate policies on. If not specified, one roll-out worker is chosen to perform + # evaluation. + eval: null + # Minimum number of environment samples to collect from the parallel roll-outs per episode / segment before moving + # on to the training phase. The actual number of env samples collected may be more than this value if we allow a + # grace period (see the comment for rollout.parallelism.grace_factor for details), but never less. This value should + # not exceed rollout.parallelism.sampling. + min_env_samples: 8 + # Factor that determines the additional wait time after the required number of environment samples as indicated by + # "min_env_samples" are received. For example, if T seconds elapsed after receiving "min_env_samples" environment + # samples, it will wait an additional T * grace_factor seconds to try to collect the remaining results. + grace_factor: 0.2 + controller: # Parallel roll-out controller settings. Ignored if rollout.parallelism section is absent. + host: "127.0.0.1" # Controller's IP address. Ignored if run in containerized environments. + port: 20000 # Controller's network port for remote roll-out workers to connect to. + logging: # log levels for roll-out workers + stdout: INFO # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS + file: DEBUG # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS +training: + # Must be "simple" or "parallel". In simple mode, all underlying models are trained locally. In parallel mode, + # all trainers send gradient-related tasks to a proxy service where they get dispatched to a set of workers. + mode: simple + # Path to load previously saved trainer snapshots from. A policy trainer's snapshot includes the states of all + # the policies it manages as well as the states of auxillary models (e.g., critics in the Actor-Critic paradigm). 
+ # If the path corresponds to an existing directory, the program will look under the directory for snapshot files + # that match the trainer names specified in the scenario and attempt to load from them. + load_path: "/path/to/your/models" # or `null` + # Which episode of the previously saved snapshots to load. If it is not provided, the last snapshot will be loaded. + load_episode: null + # Optional section to specify model checkpointing settings. + checkpointing: + # Directory to save trainer snapshots under. Snapshot files created at different episodes will be saved under + # separate folders named using episode numbers. For example, if a snapshot is created for a trainer named "dqn" + # at the end of episode 10, the file path would be "/path/to/your/checkpoint/folder/10/dqn.ckpt". + path: "/path/to/your/checkpoint/folder" + interval: 10 # Interval at which trained policies / models are persisted to disk. + proxy: # Proxy settings. Ignored if training.mode is "simple". + host: "127.0.0.1" # Proxy service host's IP address. Ignored if run in containerized environments. + frontend: 10000 # Proxy service's network port for trainers to send tasks to. + backend: 10001 # Proxy service's network port for remote workers to connect to. + num_workers: 10 # Number of workers to execute trainers' tasks. + logging: # log levels for training task workers + stdout: INFO # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS + file: DEBUG # DEBUG, INFO, WARN, ERROR, CRITICAL, PROGRESS diff --git a/maro/rl/workflows/main.py b/maro/rl/workflows/main.py new file mode 100644 index 000000000..f0eebc01b --- /dev/null +++ b/maro/rl/workflows/main.py @@ -0,0 +1,204 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +import argparse +import importlib +import os +import sys +import time +from typing import List, Type + +from maro.rl.rl_component.rl_component_bundle import RLComponentBundle +from maro.rl.rollout import BatchEnvSampler, ExpElement +from maro.rl.training import TrainingManager +from maro.rl.utils import get_torch_device +from maro.rl.utils.common import float_or_none, get_env, int_or_none, list_or_none +from maro.rl.utils.training import get_latest_ep +from maro.utils import LoggerV2 + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="MARO RL workflow parser") + parser.add_argument("--evaluate_only", action="store_true", help="Only run evaluation part of the workflow") + return parser.parse_args() + + +def main(rl_component_bundle: RLComponentBundle, args: argparse.Namespace) -> None: + if args.evaluate_only: + evaluate_only_workflow(rl_component_bundle) + else: + training_workflow(rl_component_bundle) + + +def training_workflow(rl_component_bundle: RLComponentBundle) -> None: + num_episodes = int(get_env("NUM_EPISODES")) + num_steps = int_or_none(get_env("NUM_STEPS", required=False)) + min_n_sample = int_or_none(get_env("MIN_N_SAMPLE")) + + logger = LoggerV2( + "MAIN", + dump_path=get_env("LOG_PATH"), + dump_mode="a", + stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"), + file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"), + ) + logger.info("Start training workflow.") + + env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False)) + env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False)) + parallel_rollout = env_sampling_parallelism is not None or env_eval_parallelism is not None + train_mode = get_env("TRAIN_MODE") + + 
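As a concrete illustration of the load_path / load_episode comments above and of the get_latest_ep helper added earlier in this diff (the checkpoint folder and episode numbers below are hypothetical):

    # Hypothetical layout written by checkpointing: /ckpt/5/..., /ckpt/10/..., /ckpt/15/...
    from maro.rl.utils.training import get_latest_ep

    latest = get_latest_ep("/ckpt")  # -> 15, the largest integer-named sub-folder
    # With load_path set and load_episode left null, the workflow below loads episode 15 and resumes from episode 16.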
is_single_thread = train_mode == "simple" and not parallel_rollout + if is_single_thread: + rl_component_bundle.pre_create_policy_instances() + + if parallel_rollout: + env_sampler = BatchEnvSampler( + sampling_parallelism=env_sampling_parallelism, + port=int(get_env("ROLLOUT_CONTROLLER_PORT")), + min_env_samples=int_or_none(get_env("MIN_ENV_SAMPLES", required=False)), + grace_factor=float_or_none(get_env("GRACE_FACTOR", required=False)), + eval_parallelism=env_eval_parallelism, + logger=logger, + ) + else: + env_sampler = rl_component_bundle.env_sampler + if train_mode != "simple": + for policy_name, device_name in rl_component_bundle.device_mapping.items(): + env_sampler.assign_policy_to_device(policy_name, get_torch_device(device_name)) + + # evaluation schedule + eval_schedule = list_or_none(get_env("EVAL_SCHEDULE", required=False)) + logger.info(f"Policy will be evaluated at the end of episodes {eval_schedule}") + eval_point_index = 0 + + training_manager = TrainingManager( + rl_component_bundle=rl_component_bundle, + explicit_assign_device=(train_mode == "simple"), + proxy_address=None if train_mode == "simple" else ( + get_env("TRAIN_PROXY_HOST"), int(get_env("TRAIN_PROXY_FRONTEND_PORT")) + ), + logger=logger, + ) + + load_path = get_env("LOAD_PATH", required=False) + load_episode = int_or_none(get_env("LOAD_EPISODE", required=False)) + if load_path: + assert isinstance(load_path, str) + + ep = load_episode if load_episode is not None else get_latest_ep(load_path) + path = os.path.join(load_path, str(ep)) + + loaded = env_sampler.load_policy_state(path) + logger.info(f"Loaded policies {loaded} into env sampler from {path}") + + loaded = training_manager.load(path) + logger.info(f"Loaded trainers {loaded} from {path}") + start_ep = ep + 1 + else: + start_ep = 1 + + checkpoint_path = get_env("CHECKPOINT_PATH", required=False) + checkpoint_interval = int_or_none(get_env("CHECKPOINT_INTERVAL", required=False)) + + # main loop + for ep in range(start_ep, num_episodes + 1): + collect_time = training_time = 0 + total_experiences: List[List[ExpElement]] = [] + total_info_list: List[dict] = [] + n_sample = 0 + while n_sample < min_n_sample: + tc0 = time.time() + result = env_sampler.sample( + policy_state=training_manager.get_policy_state() if not is_single_thread else None, + num_steps=num_steps, + ) + experiences: List[List[ExpElement]] = result["experiences"] + info_list: List[dict] = result["info"] + + n_sample += len(experiences[0]) + total_experiences.extend(experiences) + total_info_list.extend(info_list) + + collect_time += time.time() - tc0 + + env_sampler.post_collect(total_info_list, ep) + + logger.info(f"Roll-out completed for episode {ep}. 
Training started...") + tu0 = time.time() + training_manager.record_experiences(total_experiences) + training_manager.train_step() + if checkpoint_path and (checkpoint_interval is None or ep % checkpoint_interval == 0): + assert isinstance(checkpoint_path, str) + pth = os.path.join(checkpoint_path, str(ep)) + training_manager.save(pth) + logger.info(f"All trainer states saved under {pth}") + training_time += time.time() - tu0 + + # performance details + logger.info(f"ep {ep} - roll-out time: {collect_time:.2f} seconds, training time: {training_time:.2f} seconds") + if eval_schedule and ep == eval_schedule[eval_point_index]: + eval_point_index += 1 + result = env_sampler.eval( + policy_state=training_manager.get_policy_state() if not is_single_thread else None + ) + env_sampler.post_evaluate(result["info"], ep) + + if isinstance(env_sampler, BatchEnvSampler): + env_sampler.exit() + training_manager.exit() + + +def evaluate_only_workflow(rl_component_bundle: RLComponentBundle) -> None: + logger = LoggerV2( + "MAIN", + dump_path=get_env("LOG_PATH"), + dump_mode="a", + stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"), + file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"), + ) + logger.info("Start evaluate only workflow.") + + env_sampling_parallelism = int_or_none(get_env("ENV_SAMPLE_PARALLELISM", required=False)) + env_eval_parallelism = int_or_none(get_env("ENV_EVAL_PARALLELISM", required=False)) + parallel_rollout = env_sampling_parallelism is not None or env_eval_parallelism is not None + + if parallel_rollout: + env_sampler = BatchEnvSampler( + sampling_parallelism=env_sampling_parallelism, + port=int(get_env("ROLLOUT_CONTROLLER_PORT")), + min_env_samples=int_or_none(get_env("MIN_ENV_SAMPLES", required=False)), + grace_factor=float_or_none(get_env("GRACE_FACTOR", required=False)), + eval_parallelism=env_eval_parallelism, + logger=logger, + ) + else: + env_sampler = rl_component_bundle.env_sampler + + load_path = get_env("LOAD_PATH", required=False) + load_episode = int_or_none(get_env("LOAD_EPISODE", required=False)) + if load_path: + assert isinstance(load_path, str) + + ep = load_episode if load_episode is not None else get_latest_ep(load_path) + path = os.path.join(load_path, str(ep)) + + loaded = env_sampler.load_policy_state(path) + logger.info(f"Loaded policies {loaded} into env sampler from {path}") + + result = env_sampler.eval() + env_sampler.post_evaluate(result["info"], -1) + + if isinstance(env_sampler, BatchEnvSampler): + env_sampler.exit() + + +if __name__ == "__main__": + scenario_path = get_env("SCENARIO_PATH") + scenario_path = os.path.normpath(scenario_path) + sys.path.insert(0, os.path.dirname(scenario_path)) + module = importlib.import_module(os.path.basename(scenario_path)) + + rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls") + rl_component_bundle = rl_component_bundle_cls() + main(rl_component_bundle, args=get_args()) diff --git a/maro/rl/workflows/rollout_worker.py b/maro/rl/workflows/rollout_worker.py new file mode 100644 index 000000000..9e5ea24ce --- /dev/null +++ b/maro/rl/workflows/rollout_worker.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
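main.py above and the worker scripts that follow all resolve the scenario the same way: they import the package found at SCENARIO_PATH and read an attribute named rl_component_bundle_cls from it. A minimal sketch of the contract a scenario package's __init__.py must therefore fulfil; the module and class names here are hypothetical:

    # Hypothetical scenario package __init__.py. The workflow scripts only require an attribute
    # named "rl_component_bundle_cls" that can be instantiated with no arguments.
    from .rl_component_bundle import MyRLComponentBundle  # hypothetical module and RLComponentBundle subclass

    rl_component_bundle_cls = MyRLComponentBundle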
+import importlib +import os +import sys +from typing import Type + +from maro.rl.rl_component.rl_component_bundle import RLComponentBundle +from maro.rl.rollout import RolloutWorker +from maro.rl.utils.common import get_env, int_or_none +from maro.utils import LoggerV2 + +if __name__ == "__main__": + scenario_path = get_env("SCENARIO_PATH") + scenario_path = os.path.normpath(scenario_path) + sys.path.insert(0, os.path.dirname(scenario_path)) + module = importlib.import_module(os.path.basename(scenario_path)) + + rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls") + rl_component_bundle = rl_component_bundle_cls() + + worker_idx = int_or_none(get_env("ID")) + logger = LoggerV2( + f"ROLLOUT-WORKER.{worker_idx}", + dump_path=get_env("LOG_PATH"), + dump_mode="a", + stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"), + file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"), + ) + worker = RolloutWorker( + idx=worker_idx, + rl_component_bundle=rl_component_bundle, + producer_host=get_env("ROLLOUT_CONTROLLER_HOST"), + producer_port=int_or_none(get_env("ROLLOUT_CONTROLLER_PORT")), + logger=logger, + ) + worker.start() diff --git a/maro/rl/workflows/train_proxy.py b/maro/rl/workflows/train_proxy.py new file mode 100644 index 000000000..004066d50 --- /dev/null +++ b/maro/rl/workflows/train_proxy.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from maro.rl.training import TrainingProxy +from maro.rl.utils.common import get_env, int_or_none + +if __name__ == "__main__": + proxy = TrainingProxy( + frontend_port=int_or_none(get_env("TRAIN_PROXY_FRONTEND_PORT")), + backend_port=int_or_none(get_env("TRAIN_PROXY_BACKEND_PORT")), + ) + proxy.start() diff --git a/maro/rl/workflows/train_worker.py b/maro/rl/workflows/train_worker.py new file mode 100644 index 000000000..ace4e5fd4 --- /dev/null +++ b/maro/rl/workflows/train_worker.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+import importlib +import os +import sys +from typing import Type + +from maro.rl.rl_component.rl_component_bundle import RLComponentBundle +from maro.rl.training import TrainOpsWorker +from maro.rl.utils.common import get_env, int_or_none +from maro.utils import LoggerV2 + +if __name__ == "__main__": + scenario_path = get_env("SCENARIO_PATH") + scenario_path = os.path.normpath(scenario_path) + sys.path.insert(0, os.path.dirname(scenario_path)) + module = importlib.import_module(os.path.basename(scenario_path)) + + rl_component_bundle_cls: Type[RLComponentBundle] = getattr(module, "rl_component_bundle_cls") + rl_component_bundle = rl_component_bundle_cls() + + worker_idx = int_or_none(get_env("ID")) + logger = LoggerV2( + f"TRAIN-WORKER.{worker_idx}", + dump_path=get_env("LOG_PATH"), + dump_mode="a", + stdout_level=get_env("LOG_LEVEL_STDOUT", required=False, default="CRITICAL"), + file_level=get_env("LOG_LEVEL_FILE", required=False, default="CRITICAL"), + ) + worker = TrainOpsWorker( + idx=int_or_none(get_env("ID")), + rl_component_bundle=rl_component_bundle, + producer_host=get_env("TRAIN_PROXY_HOST"), + producer_port=int_or_none(get_env("TRAIN_PROXY_BACKEND_PORT")), + logger=logger, + ) + worker.start() diff --git a/maro/simulator/abs_core.py b/maro/simulator/abs_core.py index cdbe0362f..4244090d2 100644 --- a/maro/simulator/abs_core.py +++ b/maro/simulator/abs_core.py @@ -162,3 +162,11 @@ def get_pending_events(self, tick: int) -> list: tick (int): Specified tick. """ pass + + def get_ticks_frame_index_mapping(self) -> dict: + """Helper method to get current available ticks to related frame index mapping. + + Returns: + dict: Dictionary of avaliable tick to frame index, it would be 1 to N mapping if the resolution is not 1. + """ + pass diff --git a/maro/simulator/core.py b/maro/simulator/core.py index fa7de129f..d5f1b2305 100644 --- a/maro/simulator/core.py +++ b/maro/simulator/core.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from collections import Iterable +from collections.abc import Iterable from importlib import import_module from inspect import getmembers, isclass from typing import Generator, List, Optional, Tuple @@ -14,7 +14,6 @@ from .abs_core import AbsEnv, DecisionMode from .scenarios.abs_business_engine import AbsBusinessEngine -from .utils import random from .utils.common import tick_to_frame_index @@ -181,9 +180,8 @@ def set_seed(self, seed: int) -> None: Args: seed (int): Seed to set. """ - - if seed is not None: - random.seed(seed) + assert seed is not None and isinstance(seed, int) + self._business_engine.set_seed(seed) @property def metrics(self) -> dict: @@ -207,6 +205,14 @@ def get_pending_events(self, tick) -> List[ActualEvent]: """ return self._event_buffer.get_pending_events(tick) + def get_ticks_frame_index_mapping(self) -> dict: + """Helper method to get current available ticks to related frame index mapping. + + Returns: + dict: Dictionary of avaliable tick to frame index, it would be 1 to N mapping if the resolution is not 1. + """ + return self._business_engine.get_ticks_frame_index_mapping() + def _init_business_engine(self) -> None: """Initialize business engine object. diff --git a/maro/simulator/scenarios/__init__.py b/maro/simulator/scenarios/__init__.py index 4eff718f3..42ba8498a 100644 --- a/maro/simulator/scenarios/__init__.py +++ b/maro/simulator/scenarios/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
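A worked example of the tick-to-frame-index mapping described in the docstrings just added to AbsEnv and Env. The settings (start_tick=0, snapshot_resolution=3, max_tick=7, snapshots available for frames 0-2) are hypothetical, and the loop simply mirrors the arithmetic of the business-engine implementation further below:

    # Hypothetical settings; mirrors AbsBusinessEngine.get_ticks_frame_index_mapping below.
    start_tick, snapshot_resolution, max_tick = 0, 3, 7
    mapping = {}
    for frame_index in (0, 1, 2):  # assume a snapshot was taken for each of these frames
        frame_start_tick = start_tick + frame_index * snapshot_resolution
        frame_end_tick = min(max_tick, frame_start_tick + snapshot_resolution)
        for tick in range(frame_start_tick, frame_end_tick):
            mapping[tick] = frame_index
    assert mapping == {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1, 6: 2}  # each tick maps to the frame that covers it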
- from .abs_business_engine import AbsBusinessEngine diff --git a/maro/simulator/scenarios/abs_business_engine.py b/maro/simulator/scenarios/abs_business_engine.py index 382e8b506..daf2f480b 100644 --- a/maro/simulator/scenarios/abs_business_engine.py +++ b/maro/simulator/scenarios/abs_business_engine.py @@ -86,6 +86,26 @@ def frame_index(self, tick: int) -> int: """ return tick_to_frame_index(self._start_tick, tick, self._snapshot_resolution) + def get_ticks_frame_index_mapping(self) -> dict: + """Helper method to get current available ticks to related frame index mapping. + + Returns: + dict: Dictionary of avaliable tick to frame index, it would be 1 to N mapping if the resolution is not 1. + """ + mapping = {} + + if self.snapshots is not None: + frame_index_list = self.snapshots.get_frame_index_list() + + for frame_index in frame_index_list: + frame_start_tick = self._start_tick + frame_index * self._snapshot_resolution + frame_end_tick = min(self._max_tick, frame_start_tick + self._snapshot_resolution) + + for tick in range(frame_start_tick, frame_end_tick): + mapping[tick] = frame_index + + return mapping + def calc_max_snapshots(self) -> int: """Helper method to calculate total snapshot should be in snapshot list with parameters passed via constructor. @@ -154,6 +174,10 @@ def reset(self, keep_seed: bool = False) -> None: """Reset states business engine.""" pass + @abstractmethod + def set_seed(self, seed: int) -> None: + raise NotImplementedError + def post_step(self, tick: int) -> bool: """This method will be called at the end of each tick, used to post-process for each tick, for complex business logic with many events, it maybe not easy to determine diff --git a/maro/simulator/scenarios/cim/business_engine.py b/maro/simulator/scenarios/cim/business_engine.py index ccee78268..ec7a1c9cf 100644 --- a/maro/simulator/scenarios/cim/business_engine.py +++ b/maro/simulator/scenarios/cim/business_engine.py @@ -215,6 +215,9 @@ def reset(self, keep_seed: bool = False): self._total_operate_num = 0 + def set_seed(self, seed: int) -> None: + self._data_cntr.set_seed(seed) + def action_scope(self, port_idx: int, vessel_idx: int) -> ActionScope: """Get the action scope of specified agent. diff --git a/maro/simulator/scenarios/citi_bike/business_engine.py b/maro/simulator/scenarios/citi_bike/business_engine.py index cff8170e1..6465ff89b 100644 --- a/maro/simulator/scenarios/citi_bike/business_engine.py +++ b/maro/simulator/scenarios/citi_bike/business_engine.py @@ -170,6 +170,9 @@ def reset(self, keep_seed: bool = False): self._last_date = None + def set_seed(self, seed: int) -> None: + pass + def get_agent_idx_list(self) -> List[int]: """Get a list of agent index. diff --git a/maro/simulator/scenarios/vm_scheduling/business_engine.py b/maro/simulator/scenarios/vm_scheduling/business_engine.py index d13a49f6a..2e0dc3810 100644 --- a/maro/simulator/scenarios/vm_scheduling/business_engine.py +++ b/maro/simulator/scenarios/vm_scheduling/business_engine.py @@ -106,6 +106,10 @@ def snapshots(self) -> SnapshotList: """SnapshotList: Current snapshot list.""" return self._snapshots + @property + def pm_amount(self) -> int: + return self._pm_amount + def _load_configs(self): """Load configurations.""" # Update self._config_path with current file path. 
@@ -438,6 +442,9 @@ def reset(self, keep_seed: bool = False): self._cpu_reader.reset() + def set_seed(self, seed: int) -> None: + pass + def _init_frame(self): self._frame = build_frame( snapshots_num=self.calc_max_snapshots(), diff --git a/maro/utils/__init__.py b/maro/utils/__init__.py index 327249d4a..0ec34785e 100644 --- a/maro/utils/__init__.py +++ b/maro/utils/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .logger import DummyLogger, InternalLogger, LogFormat, Logger +from .logger import DummyLogger, LogFormat, Logger, LoggerV2 from .utils import DottableDict, clone, convert_dottable, set_seeds __all__ = [ - "Logger", "InternalLogger", "DummyLogger", "LogFormat", "convert_dottable", "DottableDict", "clone", "set_seeds" + "Logger", "LoggerV2", "DummyLogger", "LogFormat", "convert_dottable", "DottableDict", "clone", "set_seeds" ] diff --git a/maro/utils/exception/error_code.py b/maro/utils/exception/error_code.py index 7d3f9c13d..140a0a1da 100644 --- a/maro/utils/exception/error_code.py +++ b/maro/utils/exception/error_code.py @@ -47,7 +47,5 @@ 3003: "Deployment Error", # 4000-4999: Error codes for RL toolkit - 4000: "Store Misalignment", - 4001: "Missing Optimizer", - 4002: "Unrecognized Task", + 4000: "Missing Trainer", } diff --git a/maro/utils/exception/rl_toolkit_exception.py b/maro/utils/exception/rl_toolkit_exception.py index 759b8ac86..a6d6054dd 100644 --- a/maro/utils/exception/rl_toolkit_exception.py +++ b/maro/utils/exception/rl_toolkit_exception.py @@ -4,20 +4,9 @@ from .base_exception import MAROException -class StoreMisalignment(MAROException): - """Raised when a ``put`` operation on a ``SimpleStore`` would cause the underlying lists to have different - sizes.""" +class MissingTrainer(MAROException): + """ + Raised when the trainer specified in the prefix of a policy name is missing. + """ def __init__(self, msg: str = None): super().__init__(4000, msg) - - -class MissingOptimizer(MAROException): - """Raised when the optimizers are missing when calling CoreModel's step() method.""" - def __init__(self, msg: str = None): - super().__init__(4001, msg) - - -class UnrecognizedTask(MAROException): - """Raised when a CoreModel has task names that are not unrecognized by an algorithm.""" - def __init__(self, msg: str = None): - super().__init__(4002, msg) diff --git a/maro/utils/logger.py b/maro/utils/logger.py index b568c6731..37dcc0936 100644 --- a/maro/utils/logger.py +++ b/maro/utils/logger.py @@ -36,6 +36,7 @@ class LogFormat(Enum): cli_debug = 4 cli_info = 5 none = 6 + time_only = 7 FORMAT_NAME_TO_FILE_FORMAT = { @@ -49,7 +50,8 @@ class LogFormat(Enum): fmt="%(asctime)s | %(levelname)-7s | %(threadName)-10s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"), LogFormat.cli_info: logging.Formatter( fmt="%(asctime)s | %(levelname)-7s | %(threadName)-10s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"), - LogFormat.none: None + LogFormat.none: None, + LogFormat.time_only: logging.Formatter(fmt='%(asctime)s | %(message)s', datefmt='%H:%M:%S'), } FORMAT_NAME_TO_STDOUT_FORMAT = { @@ -61,7 +63,7 @@ class LogFormat(Enum): PROGRESS = 60 logging.addLevelName(PROGRESS, "PROGRESS") -level_map = { +LEVEL_MAP = { "DEBUG": logging.DEBUG, "INFO": logging.INFO, "WARN": logging.WARN, @@ -85,17 +87,14 @@ def _msgformatter(self, msg, *args): class Logger(object): """A simple wrapper for logging. - The Logger hosts a file handler and a stdout handler. 
The file handler is set to ``DEBUG`` level and will dump all the logging info to the given ``dump_folder``. The logging level of the stdout handler is decided by the ``stdout_level``, and can be redirected by setting the environment variable ``LOG_LEVEL``. Supported ``LOG_LEVEL`` includes: ``DEBUG``, ``INFO``, ``WARN``, ``ERROR``, ``CRITICAL``, ``PROCESS``. - Example: ``$ export LOG_LEVEL=INFO`` - Args: tag (str): Log tag for stream and file output. format_ (LogFormat): Predefined formatter. Defaults to ``LogFormat.full``. @@ -205,22 +204,6 @@ def critical(self, msg, *args): pass -class InternalLogger(Logger): - """An internal logger uses for recording the internal system's log.""" - - def __init__( - self, component_name: str, tag: str = "maro_internal", format_: LogFormat = LogFormat.internal, - dump_folder: str = None, dump_mode: str = 'a', extension_name: str = 'log', - auto_timestamp: bool = False - ): - current_time = f"{datetime.now().strftime('%Y%m%d%H%M')}" - self._dump_folder = dump_folder if dump_folder else \ - os.path.join(os.path.expanduser("~"), ".maro/log", current_time, str(os.getpid())) - super().__init__(tag, format_, self._dump_folder, dump_mode, extension_name, auto_timestamp) - - self._extra = {'component': component_name} - - class CliLogger: """An internal logger for CLI logging. @@ -233,18 +216,16 @@ def __init__(self): """Init singleton logger based on the ``--debug`` argument.""" self.log_level = CliGlobalParams.LOG_LEVEL current_time = f"{datetime.now().strftime('%Y%m%d')}" - self._dump_folder = os.path.join(os.path.expanduser("~/.maro/log/cli"), current_time) + dump_path = os.path.join(os.path.expanduser("~/.maro/log/cli"), current_time) if self.log_level == logging.DEBUG: super().__init__( tag='cli', - format_=LogFormat.cli_debug, dump_folder=self._dump_folder, - dump_mode='a', extension_name='log', auto_timestamp=False, stdout_level=self.log_level + format_=LogFormat.cli_debug, dump_path=dump_path, dump_mode='a', stdout_level=self.log_level ) elif self.log_level >= logging.INFO: super().__init__( tag='cli', - format_=LogFormat.cli_info, dump_folder=self._dump_folder, - dump_mode='a', extension_name='log', auto_timestamp=False, stdout_level=self.log_level + format_=LogFormat.cli_info, dump_path=dump_path, dump_mode='a', stdout_level=self.log_level ) _logger = None @@ -337,3 +318,78 @@ def error_red(self, message: str) -> None: """ self.passive_init() self._logger.error('\033[31m' + message + '\033[0m') + + +class LoggerV2(object): + """A simple wrapper for logging. + + The Logger hosts a file handler and a stdout handler. The file handler is set + to ``DEBUG`` level and will dump all logs info to the given ``dump_path``. + Supported log levels include: ``DEBUG``, ``INFO``, ``WARN``, ``ERROR``, ``CRITICAL``, ``PROCESS``. + + Args: + tag (str): Log tag for stream and file output. + format_ (LogFormat): Predefined formatter. Defaults to ``LogFormat.full``. + dump_path (str): Path of file for dumping logs. Must be an absolute path. The log level for dumping is + ``logging.DEBUG``. Defaults to None, in which case logs generated by the logger will not be dumped + to a file. + dump_mode (str): Write log file mode. Defaults to ``w``. Use ``a`` to append log. + stdout_level (str): the logging level of the stdout handler. Defaults to ``INFO``. + file_level (str): the logging level of the file handler. Defaults to ``DEBUG``. 
+ """ + + def __init__( + self, tag: str, format_: LogFormat = LogFormat.simple, dump_path: str = None, dump_mode: str = 'w', + stdout_level="INFO", file_level="DEBUG" + ): + self._file_format = FORMAT_NAME_TO_FILE_FORMAT[format_] + self._stdout_format = FORMAT_NAME_TO_STDOUT_FORMAT[format_] \ + if format_ in FORMAT_NAME_TO_STDOUT_FORMAT else \ + FORMAT_NAME_TO_FILE_FORMAT[format_] + self._stdout_level = LEVEL_MAP[stdout_level] if isinstance(stdout_level, str) else stdout_level + self._file_level = LEVEL_MAP[file_level] if isinstance(file_level, str) else file_level + self._logger = logging.getLogger(tag) + self._logger.setLevel(logging.DEBUG) + + if dump_path: + os.makedirs(os.path.dirname(dump_path), exist_ok=True) + # File handler + fh = logging.FileHandler(filename=dump_path, mode=dump_mode, encoding="utf-8") + fh.setLevel(self._file_level) + if self._file_format is not None: + fh.setFormatter(self._file_format) + self._logger.addHandler(fh) + + # Stdout handler + sh = logging.StreamHandler(sys.stdout) + sh.setLevel(self._stdout_level) + if self._stdout_format is not None: + sh.setFormatter(self._stdout_format) + self._logger.addHandler(sh) + + self._extra = {'host': socket.gethostname(), 'user': getpass.getuser(), 'tag': tag} + + @msgformat + def debug(self, msg, *args): + """Add a log with ``DEBUG`` level.""" + self._logger.debug(msg, *args, extra=self._extra) + + @msgformat + def info(self, msg, *args): + """Add a log with ``INFO`` level.""" + self._logger.info(msg, *args, extra=self._extra) + + @msgformat + def warn(self, msg, *args): + """Add a log with ``WARN`` level.""" + self._logger.warning(msg, *args, extra=self._extra) + + @msgformat + def error(self, msg, *args): + """Add a log with ``ERROR`` level.""" + self._logger.error(msg, *args, extra=self._extra) + + @msgformat + def critical(self, msg, *args): + """Add a log with ``CRITICAL`` level.""" + self._logger.critical(msg, *args, extra=self._extra) diff --git a/maro/utils/utils.py b/maro/utils/utils.py index dd83f5f7b..4131488ba 100644 --- a/maro/utils/utils.py +++ b/maro/utils/utils.py @@ -78,7 +78,8 @@ def set_seeds(seed): version_file_path = os.path.join(os.path.expanduser("~/.maro"), "version.ini") -project_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") +LOCAL_MARO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) +project_root = os.path.join(LOCAL_MARO_ROOT, "maro") target_source_pairs = [ ( diff --git a/notebooks/articles/simple_bike_repositioning.ipynb b/notebooks/articles/simple_bike_repositioning.ipynb index 94fa47fa0..74eae4066 100644 --- a/notebooks/articles/simple_bike_repositioning.ipynb +++ b/notebooks/articles/simple_bike_repositioning.ipynb @@ -728,7 +728,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/container_inventory_management/rl_formulation.ipynb b/notebooks/container_inventory_management/rl_formulation.ipynb deleted file mode 100644 index c4360d11a..000000000 --- a/notebooks/container_inventory_management/rl_formulation.ipynb +++ /dev/null @@ -1,358 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Quick Start\n", - "\n", - "This notebook demonstrates how to use MARO's reinforcement learning (RL) toolkit to solve the container inventory management ([CIM](https://maro.readthedocs.io/en/latest/scenarios/container_inventory_management.html)) problem. 
It is formalized as a multi-agent reinforcement learning problem, where each port acts as a decision agent. When a vessel arrives at a port, these agents must take actions by transfering a certain amount of containers to / from the vessel. The objective is for the agents to learn policies that minimize the cumulative container shortage. " - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "\n", - "# Common info\n", - "common_config = {\n", - " \"port_attributes\": [\"empty\", \"full\", \"on_shipper\", \"on_consignee\", \"booking\", \"shortage\", \"fulfillment\"],\n", - " \"vessel_attributes\": [\"empty\", \"full\", \"remaining_space\"],\n", - " \"action_space\": list(np.linspace(-1.0, 1.0, 21)),\n", - " # Parameters for computing states\n", - " \"look_back\": 7,\n", - " \"max_ports_downstream\": 2,\n", - " # Parameters for computing actions\n", - " \"finite_vessel_space\": True,\n", - " \"has_early_discharge\": True,\n", - " # Parameters for computing rewards\n", - " \"reward_time_window\": 99,\n", - " \"fulfillment_factor\": 1.0,\n", - " \"shortage_factor\": 1.0,\n", - " \"time_decay\": 0.97\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Shaping" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from collections import defaultdict\n", - "import numpy as np\n", - "from maro.rl import Trajectory\n", - "from maro.simulator.scenarios.cim.common import Action, ActionType\n", - "\n", - "\n", - "class CIMTrajectory(Trajectory):\n", - " def __init__(\n", - " self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream,\n", - " reward_time_window, fulfillment_factor, shortage_factor, time_decay,\n", - " finite_vessel_space=True, has_early_discharge=True \n", - " ):\n", - " super().__init__(env)\n", - " self.port_attributes = port_attributes\n", - " self.vessel_attributes = vessel_attributes\n", - " self.action_space = action_space\n", - " self.look_back = look_back\n", - " self.max_ports_downstream = max_ports_downstream\n", - " self.reward_time_window = reward_time_window\n", - " self.fulfillment_factor = fulfillment_factor\n", - " self.shortage_factor = shortage_factor\n", - " self.time_decay = time_decay\n", - " self.finite_vessel_space = finite_vessel_space\n", - " self.has_early_discharge = has_early_discharge\n", - "\n", - " def get_state(self, event):\n", - " vessel_snapshots, port_snapshots = self.env.snapshot_list[\"vessels\"], self.env.snapshot_list[\"ports\"]\n", - " tick, port_idx, vessel_idx = event.tick, event.port_idx, event.vessel_idx\n", - " ticks = [max(0, tick - rt) for rt in range(self.look_back - 1)]\n", - " future_port_idx_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')\n", - " port_features = port_snapshots[ticks: [port_idx] + list(future_port_idx_list): self.port_attributes]\n", - " vessel_features = vessel_snapshots[tick: vessel_idx: self.vessel_attributes]\n", - " return {port_idx: np.concatenate((port_features, vessel_features))}\n", - "\n", - " def get_action(self, action_by_agent, event):\n", - " vessel_snapshots = self.env.snapshot_list[\"vessels\"]\n", - " action_info = list(action_by_agent.values())[0]\n", - " model_action = action_info[0] if isinstance(action_info, tuple) else action_info\n", - " scope, tick, port, vessel = event.action_scope, event.tick, event.port_idx, event.vessel_idx\n", - " zero_action_idx = 
len(self.action_space) / 2 # index corresponding to value zero.\n", - " vessel_space = vessel_snapshots[tick:vessel:self.vessel_attributes][2] if self.finite_vessel_space else float(\"inf\")\n", - " early_discharge = vessel_snapshots[tick:vessel:\"early_discharge\"][0] if self.has_early_discharge else 0\n", - " percent = abs(self.action_space[model_action])\n", - "\n", - " if model_action < zero_action_idx:\n", - " action_type = ActionType.LOAD\n", - " actual_action = min(round(percent * scope.load), vessel_space)\n", - " elif model_action > zero_action_idx:\n", - " action_type = ActionType.DISCHARGE\n", - " plan_action = percent * (scope.discharge + early_discharge) - early_discharge\n", - " actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)\n", - " else:\n", - " actual_action, action_type = 0, ActionType.LOAD\n", - "\n", - " return {port: Action(vessel, port, actual_action, action_type)}\n", - "\n", - " def get_offline_reward(self, event):\n", - " port_snapshots = self.env.snapshot_list[\"ports\"]\n", - " start_tick = event.tick + 1\n", - " ticks = list(range(start_tick, start_tick + self.reward_time_window))\n", - "\n", - " future_fulfillment = port_snapshots[ticks::\"fulfillment\"]\n", - " future_shortage = port_snapshots[ticks::\"shortage\"]\n", - " decay_list = [\n", - " self.time_decay ** i for i in range(self.reward_time_window)\n", - " for _ in range(future_fulfillment.shape[0] // self.reward_time_window)\n", - " ]\n", - "\n", - " tot_fulfillment = np.dot(future_fulfillment, decay_list)\n", - " tot_shortage = np.dot(future_shortage, decay_list)\n", - "\n", - " return np.float32(self.fulfillment_factor * tot_fulfillment - self.shortage_factor * tot_shortage)\n", - "\n", - " def on_env_feedback(self, event, state_by_agent, action_by_agent, reward):\n", - " self.trajectory[\"event\"].append(event)\n", - " self.trajectory[\"state\"].append(state_by_agent)\n", - " self.trajectory[\"action\"].append(action_by_agent)\n", - " \n", - " def on_finish(self):\n", - " training_data = {}\n", - " for event, state, action in zip(self.trajectory[\"event\"], self.trajectory[\"state\"], self.trajectory[\"action\"]):\n", - " agent_id = list(state.keys())[0]\n", - " data = training_data.setdefault(agent_id, {\"args\": [[] for _ in range(4)]})\n", - " data[\"args\"][0].append(state[agent_id]) # state\n", - " data[\"args\"][1].append(action[agent_id][0]) # action\n", - " data[\"args\"][2].append(action[agent_id][1]) # log_p\n", - " data[\"args\"][3].append(self.get_offline_reward(event)) # reward\n", - "\n", - " for agent_id in training_data:\n", - " training_data[agent_id][\"args\"] = [\n", - " np.asarray(vals, dtype=np.float32 if i == 3 else None)\n", - " for i, vals in enumerate(training_data[agent_id][\"args\"])\n", - " ]\n", - "\n", - " return training_data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## [Agent](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#agent)\n", - "\n", - "The out-of-the-box ActorCritic is used as our agent." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "import torch.nn as nn\n", - "from torch.optim import Adam, RMSprop\n", - "\n", - "from maro.rl import ActorCritic, ActorCriticConfig, FullyConnectedBlock, OptimOption, SimpleMultiHeadModel\n", - "\n", - "# We consider the port in question as well as two downstream ports.\n", - "# We consider the states of these ports over the past 7 days plus the current day, hence the factor 8.\n", - "input_dim = (\n", - " (common_config[\"look_back\"] + 1) *\n", - " (common_config[\"max_ports_downstream\"] + 1) *\n", - " len(common_config[\"port_attributes\"]) +\n", - " len(common_config[\"vessel_attributes\"])\n", - ")\n", - "\n", - "agent_config = {\n", - " \"model\": {\n", - " \"actor\": {\n", - " \"input_dim\": input_dim,\n", - " \"output_dim\": len(common_config[\"action_space\"]),\n", - " \"hidden_dims\": [256, 128, 64],\n", - " \"activation\": nn.Tanh,\n", - " \"softmax\": True,\n", - " \"batch_norm\": False,\n", - " \"head\": True\n", - " },\n", - " \"critic\": {\n", - " \"input_dim\": input_dim,\n", - " \"output_dim\": 1,\n", - " \"hidden_dims\": [256, 128, 64],\n", - " \"activation\": nn.LeakyReLU,\n", - " \"softmax\": False,\n", - " \"batch_norm\": True,\n", - " \"head\": True\n", - " }\n", - " },\n", - " \"optimization\": {\n", - " \"actor\": OptimOption(optim_cls=Adam, optim_params={\"lr\": 0.001}),\n", - " \"critic\": OptimOption(optim_cls=RMSprop, optim_params={\"lr\": 0.001})\n", - " },\n", - " \"hyper_params\": {\n", - " \"reward_discount\": .0,\n", - " \"critic_loss_func\": nn.SmoothL1Loss(),\n", - " \"train_iters\": 10,\n", - " \"actor_loss_coefficient\": 0.1, # loss = actor_loss_coefficient * actor_loss + critic_loss\n", - " \"k\": 1, # for k-step return\n", - " \"lam\": 0.0 # lambda return coefficient\n", - " }\n", - "}\n", - "\n", - "def get_ac_agent():\n", - " actor_net = FullyConnectedBlock(**agent_config[\"model\"][\"actor\"])\n", - " critic_net = FullyConnectedBlock(**agent_config[\"model\"][\"critic\"])\n", - " ac_model = SimpleMultiHeadModel(\n", - " {\"actor\": actor_net, \"critic\": critic_net}, optim_option=agent_config[\"optimization\"],\n", - " )\n", - " return ActorCritic(ac_model, ActorCriticConfig(**agent_config[\"hyper_params\"]))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training\n", - "\n", - "This code cell demonstrates a typical single-threaded training workflow." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "14:54:17 | LEARNER | INFO | ep-0: {'order_requirements': 2240000, 'container_shortage': 1422736, 'operation_number': 4220466}\n", - "14:54:19 | LEARNER | INFO | Agent learning finished\n", - "14:54:23 | LEARNER | INFO | ep-1: {'order_requirements': 2240000, 'container_shortage': 1330641, 'operation_number': 3919970}\n", - "14:54:24 | LEARNER | INFO | Agent learning finished\n", - "14:54:29 | LEARNER | INFO | ep-2: {'order_requirements': 2240000, 'container_shortage': 996878, 'operation_number': 3226186}\n", - "14:54:30 | LEARNER | INFO | Agent learning finished\n", - "14:54:34 | LEARNER | INFO | ep-3: {'order_requirements': 2240000, 'container_shortage': 703662, 'operation_number': 3608511}\n", - "14:54:36 | LEARNER | INFO | Agent learning finished\n", - "14:54:40 | LEARNER | INFO | ep-4: {'order_requirements': 2240000, 'container_shortage': 601934, 'operation_number': 3579281}\n", - "14:54:41 | LEARNER | INFO | Agent learning finished\n", - "14:54:45 | LEARNER | INFO | ep-5: {'order_requirements': 2240000, 'container_shortage': 629344, 'operation_number': 3456707}\n", - "14:54:47 | LEARNER | INFO | Agent learning finished\n", - "14:54:51 | LEARNER | INFO | ep-6: {'order_requirements': 2240000, 'container_shortage': 560709, 'operation_number': 3511869}\n", - "14:54:52 | LEARNER | INFO | Agent learning finished\n", - "14:54:56 | LEARNER | INFO | ep-7: {'order_requirements': 2240000, 'container_shortage': 483549, 'operation_number': 3613713}\n", - "14:54:57 | LEARNER | INFO | Agent learning finished\n", - "14:55:02 | LEARNER | INFO | ep-8: {'order_requirements': 2240000, 'container_shortage': 390332, 'operation_number': 3817820}\n", - "14:55:03 | LEARNER | INFO | Agent learning finished\n", - "14:55:07 | LEARNER | INFO | ep-9: {'order_requirements': 2240000, 'container_shortage': 361151, 'operation_number': 3823994}\n", - "14:55:08 | LEARNER | INFO | Agent learning finished\n", - "14:55:13 | LEARNER | INFO | ep-10: {'order_requirements': 2240000, 'container_shortage': 442086, 'operation_number': 3647343}\n", - "14:55:14 | LEARNER | INFO | Agent learning finished\n", - "14:55:18 | LEARNER | INFO | ep-11: {'order_requirements': 2240000, 'container_shortage': 390846, 'operation_number': 3784078}\n", - "14:55:19 | LEARNER | INFO | Agent learning finished\n", - "14:55:24 | LEARNER | INFO | ep-12: {'order_requirements': 2240000, 'container_shortage': 309105, 'operation_number': 3896184}\n", - "14:55:25 | LEARNER | INFO | Agent learning finished\n", - "14:55:29 | LEARNER | INFO | ep-13: {'order_requirements': 2240000, 'container_shortage': 430801, 'operation_number': 3787247}\n", - "14:55:30 | LEARNER | INFO | Agent learning finished\n", - "14:55:35 | LEARNER | INFO | ep-14: {'order_requirements': 2240000, 'container_shortage': 368042, 'operation_number': 3793428}\n", - "14:55:36 | LEARNER | INFO | Agent learning finished\n", - "14:55:40 | LEARNER | INFO | ep-15: {'order_requirements': 2240000, 'container_shortage': 383015, 'operation_number': 3829184}\n", - "14:55:41 | LEARNER | INFO | Agent learning finished\n", - "14:55:46 | LEARNER | INFO | ep-16: {'order_requirements': 2240000, 'container_shortage': 373584, 'operation_number': 3772635}\n", - "14:55:47 | LEARNER | INFO | Agent learning finished\n", - "14:55:51 | LEARNER | INFO | ep-17: {'order_requirements': 2240000, 'container_shortage': 411397, 'operation_number': 3644350}\n", - 
"14:55:53 | LEARNER | INFO | Agent learning finished\n", - "14:55:57 | LEARNER | INFO | ep-18: {'order_requirements': 2240000, 'container_shortage': 307861, 'operation_number': 3842550}\n", - "14:55:58 | LEARNER | INFO | Agent learning finished\n", - "14:56:02 | LEARNER | INFO | ep-19: {'order_requirements': 2240000, 'container_shortage': 324650, 'operation_number': 3848202}\n", - "14:56:04 | LEARNER | INFO | Agent learning finished\n", - "14:56:08 | LEARNER | INFO | ep-20: {'order_requirements': 2240000, 'container_shortage': 367267, 'operation_number': 3739414}\n", - "14:56:09 | LEARNER | INFO | Agent learning finished\n", - "14:56:13 | LEARNER | INFO | ep-21: {'order_requirements': 2240000, 'container_shortage': 326153, 'operation_number': 3822407}\n", - "14:56:15 | LEARNER | INFO | Agent learning finished\n", - "14:56:19 | LEARNER | INFO | ep-22: {'order_requirements': 2240000, 'container_shortage': 466237, 'operation_number': 3516845}\n", - "14:56:20 | LEARNER | INFO | Agent learning finished\n", - "14:56:25 | LEARNER | INFO | ep-23: {'order_requirements': 2240000, 'container_shortage': 429538, 'operation_number': 3603386}\n", - "14:56:26 | LEARNER | INFO | Agent learning finished\n", - "14:56:30 | LEARNER | INFO | ep-24: {'order_requirements': 2240000, 'container_shortage': 241307, 'operation_number': 3986364}\n", - "14:56:31 | LEARNER | INFO | Agent learning finished\n", - "14:56:36 | LEARNER | INFO | ep-25: {'order_requirements': 2240000, 'container_shortage': 260224, 'operation_number': 3971519}\n", - "14:56:37 | LEARNER | INFO | Agent learning finished\n", - "14:56:41 | LEARNER | INFO | ep-26: {'order_requirements': 2240000, 'container_shortage': 190507, 'operation_number': 4060439}\n", - "14:56:42 | LEARNER | INFO | Agent learning finished\n", - "14:56:47 | LEARNER | INFO | ep-27: {'order_requirements': 2240000, 'container_shortage': 152822, 'operation_number': 4146195}\n", - "14:56:48 | LEARNER | INFO | Agent learning finished\n", - "14:56:52 | LEARNER | INFO | ep-28: {'order_requirements': 2240000, 'container_shortage': 91878, 'operation_number': 4300404}\n", - "14:56:53 | LEARNER | INFO | Agent learning finished\n", - "14:56:58 | LEARNER | INFO | ep-29: {'order_requirements': 2240000, 'container_shortage': 78752, 'operation_number': 4297044}\n", - "14:56:59 | LEARNER | INFO | Agent learning finished\n", - "14:57:03 | LEARNER | INFO | ep-30: {'order_requirements': 2240000, 'container_shortage': 202098, 'operation_number': 4047921}\n", - "14:57:04 | LEARNER | INFO | Agent learning finished\n", - "14:57:09 | LEARNER | INFO | ep-31: {'order_requirements': 2240000, 'container_shortage': 161871, 'operation_number': 4113281}\n", - "14:57:10 | LEARNER | INFO | Agent learning finished\n", - "14:57:14 | LEARNER | INFO | ep-32: {'order_requirements': 2240000, 'container_shortage': 74649, 'operation_number': 4311775}\n", - "14:57:16 | LEARNER | INFO | Agent learning finished\n", - "14:57:20 | LEARNER | INFO | ep-33: {'order_requirements': 2240000, 'container_shortage': 54402, 'operation_number': 4330703}\n", - "14:57:21 | LEARNER | INFO | Agent learning finished\n", - "14:57:26 | LEARNER | INFO | ep-34: {'order_requirements': 2240000, 'container_shortage': 42802, 'operation_number': 4353353}\n", - "14:57:27 | LEARNER | INFO | Agent learning finished\n", - "14:57:31 | LEARNER | INFO | ep-35: {'order_requirements': 2240000, 'container_shortage': 49236, 'operation_number': 4346898}\n", - "14:57:32 | LEARNER | INFO | Agent learning finished\n", - "14:57:37 | LEARNER | INFO | ep-36: 
{'order_requirements': 2240000, 'container_shortage': 74055, 'operation_number': 4280054}\n", - "14:57:38 | LEARNER | INFO | Agent learning finished\n", - "14:57:42 | LEARNER | INFO | ep-37: {'order_requirements': 2240000, 'container_shortage': 66899, 'operation_number': 4312042}\n", - "14:57:43 | LEARNER | INFO | Agent learning finished\n", - "14:57:48 | LEARNER | INFO | ep-38: {'order_requirements': 2240000, 'container_shortage': 29641, 'operation_number': 4385481}\n", - "14:57:49 | LEARNER | INFO | Agent learning finished\n", - "14:57:53 | LEARNER | INFO | ep-39: {'order_requirements': 2240000, 'container_shortage': 56018, 'operation_number': 4354815}\n", - "14:57:54 | LEARNER | INFO | Agent learning finished\n" - ] - } - ], - "source": [ - "from maro.simulator import Env\n", - "from maro.rl import Actor, MultiAgentWrapper, OnPolicyLearner\n", - "from maro.utils import set_seeds\n", - "\n", - "set_seeds(1024) # for reproducibility\n", - "env = Env(\"cim\", \"toy.4p_ssdd_l0.0\", durations=1120)\n", - "agent = MultiAgentWrapper({name: get_ac_agent() for name in env.agent_idx_list})\n", - "actor = Actor(env, agent, CIMTrajectory, trajectory_kwargs=common_config)\n", - "learner = OnPolicyLearner(actor, 40) # 40 episodes\n", - "learner.run()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "maro", - "language": "python", - "name": "maro" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.10" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/requirements.dev.txt b/requirements.dev.txt index 68b031e73..88b71e8ad 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,44 +1,68 @@ -Cython==0.29.14 +add-trailing-comma +altair==4.1.0 +aria2p==0.9.1 astroid==2.3.3 +azure-identity +azure-mgmt-authorization +azure-mgmt-containerservice +azure-mgmt-resource +azure-mgmt-storage +azure-storage-file-share +black==22.3.0 certifi==2019.9.11 +cryptography==36.0.1 cycler==0.10.0 -flask==1.1.2 +Cython==0.29.14 +deepdiff==5.7.0 +docker +editorconfig-checker==2.4.0 +flake8==4.0.1 flask-cors==3.0.10 +flask==1.1.2 +flask_cors==3.0.10 +flask_socketio==5.2.0 +flloat==0.3.0 +geopy==2.0.0 guppy3==3.0.9 +holidays==0.10.3 isort==4.3.21 +jinja2==2.11.3 kiwisolver==1.1.0 +kubernetes==21.7.0 lazy-object-proxy==1.4.3 +markupsafe==2.0.1 +matplotlib==3.5.2 mccabe==0.6.1 +networkx==2.4 +networkx==2.4 +numpy<1.20.0 +palettable==3.3.0 +pandas==0.25.3 +prompt_toolkit==2.0.10 +psutil==5.8.0 +ptvsd==4.3.2 +pulp==2.6.0 pyaml==20.4.0 +PyJWT==2.4.0 pyparsing==2.4.5 python-dateutil==2.8.1 -PyYAML==5.4 +PyYAML==5.4.1 pyzmq==19.0.2 -six==1.13.0 -torch==1.6.0 -torchsummary==1.5.1 -wrapt==1.11.2 -zmq==0.0.0 -numpy<1.20.0 -tabulate==0.8.5 -networkx==2.4 -palettable==3.3.0 -urllib3==1.26.5 -geopy==2.0.0 -pandas==0.25.3 +recommonmark~=0.6.0 redis==3.5.3 requests==2.25.1 -holidays==0.10.3 +scipy==1.7.0 +setuptools==58.0.4 +six==1.13.0 sphinx==1.8.6 -recommonmark~=0.6.0 sphinx_rtd_theme==1.0.0 -jinja2==2.11.3 -flake8==4.0.1 -PuLP==2.1 streamlit==0.69.1 -altair==4.1.0 -tqdm==4.51.0 -editorconfig-checker==2.4.0 -aria2p==0.9.1 -prompt_toolkit==2.0.10 stringcase==1.2.0 +tabulate==0.8.5 +termgraph==0.5.3 +torch==1.6.0 +torchsummary==1.5.1 +tqdm==4.51.0 +urllib3==1.26.5 +wrapt==1.11.2 +zmq==0.0.0 diff --git a/setup.py b/setup.py index 67ff27bf6..6d472c72c 100644 --- a/setup.py +++ b/setup.py @@ -128,6 +128,7 @@ 
install_requires=[ # TODO: use a helper function to collect these "numpy<1.20.0", + "scipy<=1.7.0", "torch<1.8.0", "holidays>=0.10.3", "pyaml>=20.4.0", @@ -137,7 +138,7 @@ "psutil<5.9.0", "deepdiff>=5.2.2", "azure-storage-blob<12.9.0", - "azure-storage-common>=2.1.0", + "azure-storage-common", "geopy>=2.0.0", "pandas<1.2", "PyYAML<5.5.0", diff --git a/tests/communication/test_decorator.py b/tests/communication/test_decorator.py index 118a604ba..140415e4a 100644 --- a/tests/communication/test_decorator.py +++ b/tests/communication/test_decorator.py @@ -13,8 +13,8 @@ def handler_function(that, proxy, message): - replied_payload = {"counter": message.payload["counter"] + 1} - proxy.reply(message, payload=replied_payload) + replied_payload = {"counter": message.body["counter"] + 1} + proxy.reply(message, body=replied_payload) sys.exit(0) @@ -61,12 +61,12 @@ def test_decorator(self): message = SessionMessage( tag="unittest", source=TestDecorator.sender_proxy.name, - destination=TestDecorator.sender_proxy.peers_name["receiver"][0], - payload={"counter": 0} + destination=TestDecorator.sender_proxy.peers["receiver"][0], + body={"counter": 0} ) replied_message = TestDecorator.sender_proxy.send(message) - self.assertEqual(message.payload["counter"] + 1, replied_message[0].payload["counter"]) + self.assertEqual(message.body["counter"] + 1, replied_message[0].body["counter"]) if __name__ == "__main__": diff --git a/tests/communication/test_proxy.py b/tests/communication/test_proxy.py index e5038588f..7065cc7a9 100644 --- a/tests/communication/test_proxy.py +++ b/tests/communication/test_proxy.py @@ -11,8 +11,7 @@ def message_receive(proxy): - for received_message in proxy.receive(is_continuous=False): - return received_message.payload + return proxy.receive_once().body @unittest.skipUnless(os.environ.get("test_with_redis", False), "require redis") @@ -51,12 +50,12 @@ def test_send(self): tag="unit_test", source=TestProxy.master_proxy.name, destination=worker_proxy.name, - payload="hello_world!" + body="hello_world!" 
) TestProxy.master_proxy.isend(send_msg) - for receive_message in worker_proxy.receive(is_continuous=False): - self.assertEqual(send_msg.payload, receive_message.payload) + recv_msg = worker_proxy.receive_once() + self.assertEqual(send_msg.body, recv_msg.body) def test_scatter(self): scatter_payload = ["worker_1", "worker_2", "worker_3", "worker_4", "worker_5"] @@ -72,8 +71,8 @@ def test_scatter(self): ) for i, worker_proxy in enumerate(TestProxy.worker_proxies): - for msg in worker_proxy.receive(is_continuous=False): - self.assertEqual(scatter_payload[i], msg.payload) + msg = worker_proxy.receive_once() + self.assertEqual(scatter_payload[i], msg.body) def test_broadcast(self): with ThreadPoolExecutor(max_workers=len(TestProxy.worker_proxies)) as executor: @@ -84,7 +83,7 @@ def test_broadcast(self): component_type="worker", tag="unit_test", session_type=SessionType.NOTIFICATION, - payload=payload + body=payload ) for task in all_tasks: @@ -97,15 +96,15 @@ def test_reply(self): tag="unit_test", source=TestProxy.master_proxy.name, destination=worker_proxy.name, - payload="hello " + body="hello " ) session_id_list = TestProxy.master_proxy.isend(send_msg) - for receive_message in worker_proxy.receive(is_continuous=False): - worker_proxy.reply(message=receive_message, tag="unit_test", payload="world!") + recv_message = worker_proxy.receive_once() + worker_proxy.reply(message=recv_message, tag="unit_test", body="world!") replied_msg_list = TestProxy.master_proxy.receive_by_id(session_id_list) - self.assertEqual(send_msg.payload + replied_msg_list[0].payload, "hello world!") + self.assertEqual(send_msg.body + replied_msg_list[0].body, "hello world!") if __name__ == "__main__": diff --git a/tests/communication/test_rejoin.py b/tests/communication/test_rejoin.py index 1830a5dbe..af79c2520 100644 --- a/tests/communication/test_rejoin.py +++ b/tests/communication/test_rejoin.py @@ -31,16 +31,16 @@ def actor_init(queue, redis_port): ) # Continuously receive messages from proxy. - for msg in proxy.receive(is_continuous=True): + for msg in proxy.receive(): print(f"receive message from master. 
{msg.tag}") if msg.tag == "cont": - proxy.reply(message=msg, tag="recv", payload="successful receive!") + proxy.reply(message=msg, tag="recv", body="successful receive!") elif msg.tag == "stop": - proxy.reply(message=msg, tag="recv", payload=f"{proxy.name} exited!") + proxy.reply(message=msg, tag="recv", body=f"{proxy.name} exited!") queue.put(proxy.name) break elif msg.tag == "finish": - proxy.reply(message=msg, tag="recv", payload=f"{proxy.name} finish!") + proxy.reply(message=msg, tag="recv", body=f"{proxy.name} finish!") sys.exit(0) proxy.close() @@ -85,7 +85,7 @@ def setUpClass(cls): **PROXY_PARAMETER ) - cls.peers = cls.master_proxy.peers_name["actor"] + cls.peers = cls.master_proxy.peers["actor"] @classmethod def tearDownClass(cls) -> None: @@ -113,7 +113,7 @@ def test_rejoin(self): tag="stop", source=TestRejoin.master_proxy.name, destination=TestRejoin.peers[1], - payload=None, + body=None, session_type=SessionType.TASK ) TestRejoin.master_proxy.isend(disconnect_message) diff --git a/tests/communication/test_zmq_driver.py b/tests/communication/test_zmq_driver.py index fcedb460a..2c98a27fb 100644 --- a/tests/communication/test_zmq_driver.py +++ b/tests/communication/test_zmq_driver.py @@ -9,8 +9,7 @@ def message_receive(driver): - for received_message in driver.receive(is_continuous=False): - return received_message.payload + return driver.receive_once().body @unittest.skipUnless(os.environ.get("test_with_zmq", False), "require zmq") @@ -45,12 +44,12 @@ def test_send(self): tag="unit_test", source="sender", destination=peer, - payload="hello_world" + body="hello_world" ) TestDriver.sender.send(message) - for received_message in TestDriver.receivers[peer].receive(is_continuous=False): - self.assertEqual(received_message.payload, message.payload) + recv_message = TestDriver.receivers[peer].receive_once() + self.assertEqual(recv_message.body, message.body) def test_broadcast(self): executor = ThreadPoolExecutor(max_workers=len(TestDriver.peer_list)) @@ -60,13 +59,13 @@ def test_broadcast(self): tag="unit_test", source="sender", destination="*", - payload="hello_world" + body="hello_world" ) TestDriver.sender.broadcast(topic="receiver", message=message) for task in as_completed(all_task): res = task.result() - self.assertEqual(res, message.payload) + self.assertEqual(res, message.body) if __name__ == "__main__": diff --git a/tests/dummy/dummy_business_engine.py b/tests/dummy/dummy_business_engine.py index fd635863b..c0ada59ff 100644 --- a/tests/dummy/dummy_business_engine.py +++ b/tests/dummy/dummy_business_engine.py @@ -54,3 +54,6 @@ def get_node_info(self): def get_agent_idx_list(self): return [node.index for node in self._dummy_list] + + def set_seed(self, seed: int) -> None: + pass diff --git a/tests/requirements.test.txt b/tests/requirements.test.txt index 3de5a8a2d..221ee85cd 100644 --- a/tests/requirements.test.txt +++ b/tests/requirements.test.txt @@ -11,7 +11,7 @@ requests<=2.26.0 psutil<5.9.0 deepdiff>=5.2.2 azure-storage-blob<12.9.0 -azure-storage-common>=2.1.0 +azure-storage-common torch<1.8.0 pytest coverage @@ -20,5 +20,4 @@ paramiko>=2.7.2 pytz==2019.3 aria2p==0.9.1 kubernetes>=12.0.1 -PyYAML<5.5.0 - +PyYAML<5.5.0 \ No newline at end of file diff --git a/tests/test_env.py b/tests/test_env.py index ae5fce9e7..be9f250ac 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -5,10 +5,11 @@ import unittest import numpy as np +from math import floor from .dummy.dummy_business_engine import DummyEngine from maro.simulator.utils import get_available_envs, get_scenarios, 
get_topologies -from maro.simulator.utils.common import frame_index_to_ticks +from maro.simulator.utils.common import frame_index_to_ticks, tick_to_frame_index from maro.simulator.core import BusinessEngineNotFoundError, Env from tests.utils import backends_to_test @@ -283,7 +284,7 @@ def test_get_avaiable_envs(self): env_list = get_available_envs() - self.assertEqual(len(env_list), len(cim_topoloies) + len(citi_bike_topologies) + len(vm_topoloties)) + self.assertEqual(len(env_list), len(cim_topoloies) + len(citi_bike_topologies) + len(vm_topoloties) + len(get_topologies("supply_chain"))) def test_frame_index_to_ticks(self): ticks = frame_index_to_ticks(0, 10, 2) @@ -293,6 +294,44 @@ self.assertListEqual([0, 1], ticks[0]) self.assertListEqual([8, 9], ticks[4]) + def test_get_available_frame_index_to_ticks_with_default_resolution(self): + for backend_name in backends_to_test: + os.environ["DEFAULT_BACKEND_NAME"] = backend_name + + max_tick = 10 + + env = Env(scenario="cim", topology="tests/data/cim/customized_config", + start_tick=0, durations=max_tick) + + run_to_end(env) + + t2f_mapping = env.get_ticks_frame_index_mapping() + + # tick == frame index + self.assertListEqual([t for t in t2f_mapping.keys()], [t for t in range(max_tick)]) + self.assertListEqual([f for f in t2f_mapping.values()], [f for f in range(max_tick)]) + + def test_get_available_frame_index_to_ticks_with_resolution2(self): + for backend_name in backends_to_test: + os.environ["DEFAULT_BACKEND_NAME"] = backend_name + + max_tick = 10 + start_tick = 0 + resolution = 2 + + env = Env(scenario="cim", topology="tests/data/cim/customized_config", + start_tick=start_tick, durations=max_tick, snapshot_resolution=resolution) + + run_to_end(env) + + t2f_mapping = env.get_ticks_frame_index_mapping() + + self.assertListEqual([t for t in t2f_mapping.keys()], [t for t in range(max_tick)]) + + for t, v in t2f_mapping.items(): + v2 = tick_to_frame_index(start_tick, t, resolution) + self.assertEqual(v, v2) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_store.py b/tests/test_store.py deleted file mode 100644 index fb6084277..000000000 --- a/tests/test_store.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license.
- -import unittest - -from maro.rl import SimpleStore, OverwriteType - - -class TestUnboundedStore(unittest.TestCase): - def setUp(self) -> None: - self.store = SimpleStore(["a", "b", "c"]) - - def tearDown(self) -> None: - self.store.clear() - - def test_put(self): - indexes = self.store.put({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) - expected = [0, 1, 2] - self.assertEqual(indexes, expected, msg=f"expected returned indexes = {expected}, got {indexes}") - indexes = self.store.put({"a": [10, 11], "b": [12, 13], "c": [14, 15]}) - expected = [3, 4] - self.assertEqual(indexes, expected, msg=f"expected returned indexes = {expected}, got {indexes}") - - def test_get(self): - self.store.put({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "c": [9, 10, 11, 12]}) - indexes = [1, 3] - actual = self.store.get(indexes) - expected = {"a": [2, 4], "b": [6, 8], "c": [10, 12]} - self.assertEqual(actual, expected, msg=f"expected {expected}, got {actual}") - - def test_update(self): - self.store.put({"a": [1, 2, 3, 4, 5], "b": [6, 7, 8, 9, 10], "c": [11, 12, 13, 14, 15]}) - self.store.update([0, 3], {"a": [-1, -4], "c": [-11, -14]}) - actual = self.store.dumps() - expected = {"a": [-1, 2, 3, -4, 5], "b": [6, 7, 8, 9, 10], "c": [-11, 12, 13, -14, 15]} - self.assertEqual(actual, expected, msg=f"expected store content = {expected}, got {actual}") - - def test_filter(self): - self.store.put({"a": [1, 2, 3, 4, 5], "b": [6, 7, 8, 9, 10], "c": [11, 12, 13, 14, 15]}) - result = self.store.apply_multi_filters(filters=[lambda x: x["a"] > 2, lambda x: sum(x.values()) % 2 == 0])[1] - expected = {"a": [3, 5], "b": [8, 10], "c": [13, 15]} - self.assertEqual(result, expected, msg=f"expected {expected}, got {result}") - - -class TestFixedSizeStore(unittest.TestCase): - def test_put_with_rolling_overwrite(self): - store = SimpleStore(["a", "b", "c"], capacity=5, overwrite_type=OverwriteType.ROLLING) - indexes = store.put({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) - expected = [0, 1, 2] - self.assertEqual(indexes, expected, msg=f"expected indexes = {expected}, got {indexes}") - indexes = store.put({"a": [10, 11, 12, 13], "b": [14, 15, 16, 17], "c": [18, 19, 20, 21]}) - expected = [-2, -1, 0, 1] - self.assertEqual(indexes, expected, msg=f"expected indexes = {expected}, got {indexes}") - actual = store.dumps() - expected = {"a": [12, 13, 3, 10, 11], "b": [16, 17, 6, 14, 15], "c": [20, 21, 9, 18, 19]} - self.assertEqual(actual, expected, msg=f"expected store content = {expected}, got {actual}") - - def test_put_with_random_overwrite(self): - store = SimpleStore(["a", "b", "c"], capacity=5, overwrite_type=OverwriteType.RANDOM) - indexes = store.put({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) - indexes_2 = store.put({"a": [10, 11, 12, 13], "b": [14, 15, 16, 17], "c": [18, 19, 20, 21]}) - for i in indexes_2[2:]: - self.assertIn(i, indexes, msg=f"expected overwrite index in {indexes}, got {i}") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_trajectory_utils.py b/tests/test_trajectory_utils.py deleted file mode 100644 index df05d8032..000000000 --- a/tests/test_trajectory_utils.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -import unittest - -import numpy as np - -from maro.rl.utils.trajectory_utils import get_k_step_returns, get_lambda_returns - - -class TestTrajectoryUtils(unittest.TestCase): - def setUp(self) -> None: - self.rewards = np.asarray([3, 2, 4, 1, 5]) - self.values = np.asarray([4, 7, 1, 3, 6]) - self.lam = 0.6 - self.discount = 0.8 - self.k = 4 - - def test_k_step_return(self): - returns = get_k_step_returns(self.rewards, self.values, self.discount, k=self.k) - expected = np.asarray([10.1296, 8.912, 8.64, 5.8, 6.0]) - np.testing.assert_allclose(returns, expected, rtol=1e-4) - - def test_lambda_return(self): - returns = get_lambda_returns(self.rewards, self.values, self.discount, self.lam, k=self.k) - expected = np.asarray([8.1378176, 6.03712, 7.744, 5.8, 6.0]) - np.testing.assert_allclose(returns, expected, rtol=1e-4) - - -if __name__ == "__main__": - unittest.main()
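For reference, here is a minimal usage sketch of the ``LoggerV2`` wrapper added above. The import below assumes the class lives in ``maro.utils.logger`` (the module patched earlier in this diff); the tag, dump path, and levels are illustrative values rather than anything prescribed by the patch.

from maro.utils.logger import LogFormat, LoggerV2

# One file handler (level set by ``file_level``, DEBUG by default) plus one stdout handler
# (level set by ``stdout_level``).
logger = LoggerV2(
    tag="train",
    format_=LogFormat.simple,
    dump_path="/tmp/maro/train.log",  # must be absolute; parent directories are created automatically
    dump_mode="a",
    stdout_level="INFO",
    file_level="DEBUG",
)

logger.debug("debug message")  # goes to the file only, since the stdout handler filters below INFO
logger.info("info message")    # goes to both the file and stdout
logger.warn("warn message")    # forwarded to the underlying logging.Logger.warning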