Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions maro/simulator/abs_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@

class DecisionMode(IntEnum):
"""Decision mode that interactive with agent."""
# Ask agent for action one by one.
# Ask agent to take action one by one.
Sequential = 0
# Ask agent for action at same time, not supported yes.
# Ask agent to take action at same time, not supported yet.
Joint = 1


class AbsEnv(ABC):
"""The main MARO simulator abstract class, which provides interfaces to agents.

Args:
scenario (str): Scenario name under maro/sim/scenarios folder.
scenario (str): Scenario name under maro/simulator/scenarios folder.
topology (str): Topology name under specified scenario folder.
start_tick (int): Start tick of the scenario, usually used for pre-processed data streaming.
durations (int): Duration ticks of this environment from start_tick.
Expand All @@ -34,11 +34,13 @@ class AbsEnv(ABC):
options (dict): Additional parameters passed to business engine.
"""

def __init__(self, scenario: str, topology: str,
start_tick: int, durations: int, snapshot_resolution: int, max_snapshots: int,
decision_mode: DecisionMode,
business_engine_cls: type,
options: dict):
def __init__(
self, scenario: str, topology: str,
start_tick: int, durations: int, snapshot_resolution: int, max_snapshots: int,
decision_mode: DecisionMode,
business_engine_cls: type,
options: dict
):
Comment thread
ArthurJiang marked this conversation as resolved.
self._tick = start_tick
self._scenario = scenario
self._topology = topology
Expand All @@ -61,7 +63,7 @@ def step(self, action):
action (Action): Action(s) from agent.

Returns:
tuple: a tuple of (reward, decision event, is_done).
tuple: a tuple of (metrics, decision event, is_done).
"""
pass

Expand All @@ -78,8 +80,7 @@ def reset(self):
@property
@abstractmethod
def configs(self) -> dict:
"""object: Configurations of current environment,
this field would be different for different scenario."""
"""object: Configurations of current environment, this field would be different for different scenario."""
pass

@property
Expand Down Expand Up @@ -114,15 +115,14 @@ def summary(self) -> dict:
@property
@abstractmethod
def snapshot_list(self) -> SnapshotList:
"""SnapshotList: Current snapshot list, a snapshot list contains all the snapshots of frame at each tick.
"""
"""SnapshotList: Current snapshot list, a snapshot list contains all the snapshots of frame at each tick."""
pass

def set_seed(self, seed: int):
"""Set random seed used by simulator.

NOTE:
This will not set seed for python random or other packages' seed, such as numpy.
This will not set seed for Python random or other packages' seed, such as NumPy.

Args:
seed (int): Seed to set.
Expand All @@ -139,8 +139,7 @@ def metrics(self) -> dict:
return {}

def get_finished_events(self) -> List[Event]:
"""List[Event]: All events finished so far.
"""
"""List[Event]: All events finished so far."""
pass

def get_pending_events(self, tick: int) -> List[Event]:
Expand Down
127 changes: 65 additions & 62 deletions maro/simulator/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

from maro.backends.frame import FrameBase, SnapshotList
from maro.event_buffer import DECISION_EVENT, EventBuffer, EventState
from maro.utils.exception.simulator_exception import \
BusinessEngineNotFoundError
from maro.utils.exception.simulator_exception import BusinessEngineNotFoundError

from .abs_core import AbsEnv, DecisionMode
from .scenarios.abs_business_engine import AbsBusinessEngine
Expand All @@ -21,37 +20,42 @@ class Env(AbsEnv):
"""Default environment implementation using generator.

Args:
scenario (str): Scenario name under maro/sim/scenarios folder.
topology (str): Topology name under specified scenario folder,
if this point to a existing folder, then it will use this as topology for built-in scenario.
scenario (str): Scenario name under maro/simulator/scenarios folder.
topology (str): Topology name under specified scenario folder.
If it points to an existing folder, the corresponding topology will be used for the built-in scenario.
start_tick (int): Start tick of the scenario, usually used for pre-processed data streaming.
durations (int): Duration ticks of this environment from start_tick.
snapshot_resolution (int): How many ticks will take a snapshot.
max_snapshots(int): Max in-memory snapshot number, default None means keep all snapshots in memory,
when taking a snapshot, if it reaches this limitation, oldest one will be overwrote.
business_engine_cls : Class of business engine, if specified, then use it to construct be instance,
or will search internal by scenario.
max_snapshots(int): Max in-memory snapshot number.
When the number of dumped snapshots reached the limitation, oldest one will be overwrote by new one.
None means keeping all snapshots in memory. Defaults to None.
Comment thread
Jinyu-W marked this conversation as resolved.
business_engine_cls: Class of business engine. If specified, use it to construct the be instance,
or search internally by scenario.
options (dict): Additional parameters passed to business engine.
"""

def __init__(self, scenario: str = None, topology: str = None,
start_tick: int = 0, durations: int = 100, snapshot_resolution: int = 1, max_snapshots: int = None,
decision_mode: DecisionMode = DecisionMode.Sequential,
business_engine_cls: type = None,
options: dict = {}):
super().__init__(scenario, topology, start_tick, durations,
snapshot_resolution, max_snapshots, decision_mode, business_engine_cls, options)
def __init__(
self, scenario: str = None, topology: str = None,
start_tick: int = 0, durations: int = 100, snapshot_resolution: int = 1, max_snapshots: int = None,
decision_mode: DecisionMode = DecisionMode.Sequential,
business_engine_cls: type = None,
options: dict = {}
):
super().__init__(
scenario, topology, start_tick, durations,
snapshot_resolution, max_snapshots, decision_mode, business_engine_cls, options
)

self._name = f'{self._scenario}:{self._topology}' if business_engine_cls is None \
else business_engine_cls.__name__
self._business_engine: AbsBusinessEngine = None

self._event_buffer = EventBuffer()

# generator to push the simulator moving on
# The generator used to push the simulator forward.
self._simulate_generator = self._simulate()

# initialize business
# Initialize the business engine.
self._init_business_engine()

def step(self, action):
Expand All @@ -61,18 +65,18 @@ def step(self, action):
action (Action): Action(s) from agent.

Returns:
tuple: a tuple of (reward, decision event, is_done).
tuple: a tuple of (metrics, decision event, is_done).
"""
try:
reward, decision_event, _is_done = self._simulate_generator.send(
metrics, decision_event, _is_done = self._simulate_generator.send(
action)
except StopIteration:
return None, None, True

return reward, decision_event, _is_done
return metrics, decision_event, _is_done

def dump(self):
"""Dump environment for restore
"""Dump environment for restore.

NOTE:
Not implemented.
Expand All @@ -97,8 +101,7 @@ def configs(self) -> dict:

@property
def summary(self) -> dict:
"""dict: Summary about current simulator, include node details, and mappings.
"""
"""dict: Summary about current simulator, including node details and mappings."""
return {
"node_mapping": self._business_engine.get_node_mapping(),
"node_detail": self.current_frame.get_node_info()
Expand Down Expand Up @@ -126,7 +129,9 @@ def frame_index(self) -> int:

@property
def snapshot_list(self) -> SnapshotList:
"""SnapshotList: A snapshot list contains all the snapshots of frame at each tick.
"""SnapshotList: A snapshot list containing all the snapshots of frame at each dump point.

NOTE: Due to different environment configurations, the resolution of the snapshot may be different.
"""
return self._business_engine.snapshots

Expand All @@ -139,7 +144,7 @@ def set_seed(self, seed: int):
"""Set random seed used by simulator.

NOTE:
This will not set seed for python random or other packages' seed, such as numpy.
This will not set seed for Python random or other packages' seed, such as NumPy.

Args:
seed (int): Seed to set.
Expand All @@ -159,8 +164,7 @@ def metrics(self) -> dict:
return self._business_engine.get_metrics()

def get_finished_events(self):
"""List[Event]: All events finished so far.
"""
"""List[Event]: All events finished so far."""
return self._event_buffer.get_finished_events()

def get_pending_events(self, tick):
Expand All @@ -175,68 +179,68 @@ def _init_business_engine(self):
"""Initialize business engine object.

NOTE:
1. For built-in scenarios will always under "maro/simulator/scenarios" folder.
2. For external scenarios, we access the business engine class to create instance.
1. For built-in scenarios, they will always under "maro/simulator/scenarios" folder.
2. For external scenarios, the business engine instance is built with the loaded business engine class.
"""
max_tick = self._start_tick + self._durations

if self._business_engine_cls is not None:
business_class = self._business_engine_cls
else:
# combine the business engine import path
# Combine the business engine import path.
business_class_path = f'maro.simulator.scenarios.{self._scenario}.business_engine'

# load the module to find business engine for that scenario
# Load the module to find business engine for that scenario.
business_module = import_module(business_class_path)

business_class = None

for _, obj in getmembers(business_module, isclass):
if issubclass(obj, AbsBusinessEngine) and obj != AbsBusinessEngine:
# we find it
# We find it.
business_class = obj

break

if business_class is None:
raise BusinessEngineNotFoundError()

self._business_engine = business_class(event_buffer=self._event_buffer,
topology=self._topology,
start_tick=self._start_tick,
max_tick=max_tick,
snapshot_resolution=self._snapshot_resolution,
max_snapshots=self._max_snapshots,
additional_options=self._additional_options)
self._business_engine = business_class(
event_buffer=self._event_buffer,
topology=self._topology,
start_tick=self._start_tick,
max_tick=max_tick,
snapshot_resolution=self._snapshot_resolution,
max_snapshots=self._max_snapshots,
additional_options=self._additional_options
)

def _simulate(self):
"""
This is the generator to wrap each episode process.
"""
"""This is the generator to wrap each episode process."""
is_end_tick = False

while True:
# ask business engine to do thing for this tick, such as gen and push events
# we do not push events now
# Ask business engine to do thing for this tick, such as generating and pushing events.
# We do not push events now.
self._business_engine.step(self._tick)

while True:
# we keep process all the events, until no more any events
# Keep processing events, until no more events in this tick.
pending_events = self._event_buffer.execute(self._tick)

# processing pending events
# Processing pending events.
pending_event_length: int = len(pending_events)

if pending_event_length == 0:
# we have processed all the event of current tick, lets go for next tick
# We have processed all the event of current tick, lets go for next tick.
break

# insert snapshot before each action
# Insert snapshot before each action.
self._business_engine.frame.take_snapshot(self.frame_index)

decision_events = []

# append source event id to decision events, to support sequential action in joint mode
# Append source event id to decision events, to support sequential action in joint mode.
for evt in pending_events:
payload = evt.payload

Expand All @@ -247,31 +251,30 @@ def _simulate(self):
decision_events = decision_events[0] if self._decision_mode == DecisionMode.Sequential \
else decision_events

# yield current state first, and waiting for action
# Yield current state first, and waiting for action.
actions = yield self._business_engine.get_metrics(), decision_events, False

if actions is None:
actions = [] # make business engine easy to work
# Make business engine easy to work.
actions = []

if actions is not None and not isinstance(actions, Iterable):
actions = [actions]

# generate a new atom event first
# Generate a new atom event first.
action_event = self._event_buffer.gen_atom_event(self._tick, DECISION_EVENT, actions)

# 3. we just append the action into sub event of first pending cascade event
# We just append the action into sub event of first pending cascade event.
pending_events[0].state = EventState.EXECUTING
pending_events[0].immediate_event_list.append(action_event)

# TODO: support get reward after action complete here, via using event_buffer.execute

if self._decision_mode == DecisionMode.Joint:
# for joint event, we will disable following cascade event
# For joint event, we will disable following cascade event.

# we expect that first action contains a src_event_id to support joint event with sequential action
# We expect that first action contains a src_event_id to support joint event with sequential action.
action_related_event_id = None if len(actions) == 1 else getattr(actions[0], "src_event_id", None)

# if first action have decision event attached, then means support sequential action
# If the first action has a decision event attached, it means sequential action is supported.
is_support_seq_action = action_related_event_id is not None

if is_support_seq_action:
Expand All @@ -282,17 +285,17 @@ def _simulate(self):
for i in range(1, pending_event_length):
pending_events[i].state = EventState.FINISHED

# check if we should end simulation
# Check the end tick of the simulation to decide if we should end the simulation.
is_end_tick = self._business_engine.post_step(self._tick)

if is_end_tick:
break

self._tick += 1

# make sure we have no missing data
# Make sure we have no missing data.
if (self._tick + 1) % self._snapshot_resolution != 0:
self._business_engine.frame.take_snapshot(self.frame_index)

# the end
# The end.
yield self._business_engine.get_metrics(), None, True
Loading