Skip to content

Commit

Permalink
CIM scenario refinement (#400)
Browse files Browse the repository at this point in the history
* Cim scenario refinement (#394)

* CIM refinement

* Fix lint error

* Fix lint error

* Cim test coverage (#395)

* Enrich tests

* Refactor CimDataGenerator

* Refactor CIM parsers

* Minor refinement

* Fix lint error

* Fix lint error

* Fix lint error

* Minor refactor

* Type

* Add two test file folders. Make a slight change to CIM BE.

* Lint error

* Lint error

* Remove unnecessary public interfaces of CIM BE

* Cim disable auto action type detection (#399)

* Haven't been tested

* Modify document

* Add ActionType checking

* Minor

* Lint error

* Action quantity should be a positive number

* Modify related docs & notebooks

* Minor

* Change test file name. Prepare to merge into master.

* .

* Minor test patch
  • Loading branch information
lihuoran committed Sep 22, 2021
1 parent 39aaa92 commit 56fcfa2
Show file tree
Hide file tree
Showing 63 changed files with 4,996 additions and 1,790 deletions.
5 changes: 2 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,5 @@ data/
maro_venv/
pyvenv.cfg
htmlcov/
.coverage

.coveragerc
.coverage
.coveragerc
2 changes: 1 addition & 1 deletion docs/source/examples/multi_agent_dqn_cim.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ in the roll-out loop. In this example,
plan_action = percent * (scope.discharge + early_discharge) - early_discharge
actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)
else:
actual_action, action_type = 0, None
actual_action, action_type = 0, ActionType.LOAD
return {port: Action(vessel, port, actual_action, action_type)}
Expand Down
2 changes: 1 addition & 1 deletion docs/source/scenarios/container_inventory_management.rst
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ Once we get a ``DecisionEvent`` from the environment, we should respond with an
* **vessel_idx** (int): The id of the vessel/operation object of the port/agent.
* **port_idx** (int): The id of the port/agent that takes this action.
* **action_type** (ActionType): Whether to load or discharge empty containers in this action.
* **quantity** (int): The quantity of empty containers to be loaded/discharged.
* **quantity** (int): The (non-negative) quantity of empty containers to be loaded/discharged.

Example
^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions examples/cim/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class CIMTrajectory(Trajectory):
def __init__(
self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream,
reward_time_window, fulfillment_factor, shortage_factor, time_decay,
finite_vessel_space=True, has_early_discharge=True
finite_vessel_space=True, has_early_discharge=True
):
super().__init__(env)
self.port_attributes = port_attributes
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_action(self, action_by_agent, event):
plan_action = percent * (scope.discharge + early_discharge) - early_discharge
actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)
else:
actual_action, action_type = 0, None
actual_action, action_type = 0, ActionType.LOAD

return {port: Action(vessel, port, actual_action, action_type)}

Expand Down
33 changes: 18 additions & 15 deletions maro/data_lib/cim/cim_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .entities import (
CimBaseDataCollection, CimRealDataCollection, CimSyntheticDataCollection, NoisedItem, Order, OrderGenerateMode,
PortSetting, VesselSetting
PortSetting, SyntheticPortSetting, VesselSetting
)
from .port_buffer_tick_wrapper import PortBufferTickWrapper
from .utils import BUFFER_TICK_RAND_KEY, ORDER_NUM_RAND_KEY, apply_noise, list_sum_normalize
Expand Down Expand Up @@ -42,7 +42,7 @@ class CimBaseDataContainer(ABC):
Args:
data_collection (CimBaseDataCollection): Corresponding data collection.
"""
def __init__(self, data_collection: CimBaseDataCollection):
def __init__(self, data_collection: CimBaseDataCollection) -> None:
self._data_collection = data_collection

# wrapper for interfaces, to make it easy to use
Expand Down Expand Up @@ -151,7 +151,7 @@ def full_return_buffers(self) -> PortBufferTickWrapper:
.. code-block:: python
# Get full return buffer tick of port 0.
buffer_tick = data_cnr.full_return_buffers[0]
buffer_tick = data_cntr.full_return_buffers[0]
"""
return self._full_return_buffer_wrapper

Expand Down Expand Up @@ -209,7 +209,7 @@ def reachable_stops(self) -> VesselReachableStopsWrapper:
return self._reachable_stops_wrapper

@property
def vessel_period(self) -> int:
def vessel_period(self) -> List[int]:
"""Wrapper to get vessel's planned sailing period (without noise to complete a whole route).
Examples:
Expand Down Expand Up @@ -241,7 +241,7 @@ def reset(self):
self._is_need_reset_seed = True

def _reset_seed(self):
"""Reset internal seed for generate reproduceable data"""
"""Reset the internal seed so that the generated data is reproducible."""
random.reset_seed(BUFFER_TICK_RAND_KEY)

@abstractmethod
Expand Down Expand Up @@ -288,14 +288,14 @@ def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:

self._is_need_reset_seed = False

if tick >= self._data_collection.max_tick:
if tick >= self._data_collection.max_tick: # pragma: no cover
warnings.warn(f"{tick} out of max tick {self._data_collection.max_tick}")
return []

return self._gen_orders(tick, total_empty_container)

def _reset_seed(self):
"""Reset internal seed for generate reproduceable data"""
"""Reset the internal seeds so that the generated data is reproducible."""
super()._reset_seed()
random.reset_seed(ORDER_NUM_RAND_KEY)

Expand All @@ -308,6 +308,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
"""
# result
order_list: List[Order] = []
assert isinstance(self._data_collection, CimSyntheticDataCollection)
order_proportion = self._data_collection.order_proportion
order_mode = self._data_collection.order_mode
total_containers = self._data_collection.total_containers
Expand All @@ -316,7 +317,7 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
orders_to_gen = int(order_proportion[tick])

# if under unfixed mode, we will consider current empty container as factor
if order_mode == OrderGenerateMode.UNFIXED:
if order_mode == OrderGenerateMode.UNFIXED: # pragma: no cover. TODO: remove this mark later
delta = total_containers - total_empty_container

if orders_to_gen <= delta:
Expand All @@ -331,7 +332,9 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:

# calculate orders distribution for each port as source
for port_idx in range(self.port_number):
source_dist: NoisedItem = self.ports[port_idx].source_proportion
port = self.ports[port_idx]
assert isinstance(port, SyntheticPortSetting)
source_dist: NoisedItem = port.source_proportion

noised_source_order_number = apply_noise(source_dist.base, source_dist.noise, random[ORDER_NUM_RAND_KEY])

Expand All @@ -346,7 +349,9 @@ def _gen_orders(self, tick: int, total_empty_container: int) -> List[Order]:
if remaining_orders == 0:
break

targets_dist: List[NoisedItem] = self.ports[port_idx].target_proportions
port = self.ports[port_idx]
assert isinstance(port, SyntheticPortSetting)
targets_dist: List[NoisedItem] = port.target_proportions

# apply noise and normalize
noised_targets_dist = list_sum_normalize(
Expand Down Expand Up @@ -403,6 +408,7 @@ def __init__(self, data_collection: CimRealDataCollection):
super().__init__(data_collection)

# orders
assert isinstance(self._data_collection, CimRealDataCollection)
self._orders: Dict[int, List[Order]] = self._data_collection.orders

def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:
Expand All @@ -422,11 +428,8 @@ def get_orders(self, tick: int, total_empty_container: int) -> List[Order]:

self._is_need_reset_seed = False

if tick >= self._data_collection.max_tick:
if tick >= self._data_collection.max_tick: # pragma: no cover
warnings.warn(f"{tick} out of max tick {self._data_collection.max_tick}")
return []

if tick not in self._orders:
return []

return self._orders[tick]
return self._orders[tick] if tick in self._orders else []
14 changes: 8 additions & 6 deletions maro/data_lib/cim/cim_data_container_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,21 @@

import os
import urllib.parse
from typing import Optional

from maro.cli.data_pipeline.utils import StaticParameter
from maro.simulator.utils import random, seed

from .cim_data_container import CimBaseDataContainer, CimRealDataContainer, CimSyntheticDataContainer
from .cim_data_generator import CimDataGenerator
from .cim_data_generator import gen_cim_data
from .cim_data_loader import load_from_folder, load_real_data_from_folder
from .utils import DATA_CONTAINER_INIT_SEED_LIMIT, ROUTE_INIT_RAND_KEY


class CimDataContainerWrapper:

def __init__(self, config_path: str, max_tick: int, topology: str):
self._data_cntr: CimBaseDataContainer = None
self._data_cntr: Optional[CimBaseDataContainer] = None
self._max_tick = max_tick
self._config_path = config_path
self._start_tick = 0
Expand All @@ -39,11 +40,13 @@ def _init_data_container(self, topology_seed: int = None):
config_path=config_path, max_tick=self._max_tick, start_tick=self._start_tick,
topology_seed=topology_seed
)
elif os.path.exists(os.path.join(self._config_path, "order_proportion.csv")):
self._data_cntr = data_from_dumps(dumps_folder=self._config_path)
else:
# Real Data Mode: read data from input data files, no need for any config.yml.
self._data_cntr = data_from_files(data_folder=self._config_path)

def reset(self, keep_seed):
def reset(self, keep_seed: bool):
"""Reset data container internal state"""
if not keep_seed:
self._init_data_container(random[ROUTE_INIT_RAND_KEY].randint(0, DATA_CONTAINER_INIT_SEED_LIMIT - 1))
Expand Down Expand Up @@ -87,9 +90,8 @@ def data_from_generator(config_path: str, max_tick: int, start_tick: int = 0,
Returns:
CimSyntheticDataContainer: Data container used to provide cim data related interfaces.
"""
edg = CimDataGenerator()

data_collection = edg.gen_data(config_path, start_tick=start_tick, max_tick=max_tick, topology_seed=topology_seed)
data_collection = gen_cim_data(
config_path, start_tick=start_tick, max_tick=max_tick, topology_seed=topology_seed)

return CimSyntheticDataContainer(data_collection)

Expand Down
53 changes: 27 additions & 26 deletions maro/data_lib/cim/cim_data_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,25 @@
import numpy as np
from yaml import safe_dump

from .cim_data_generator import CimDataGenerator
from .entities import CimSyntheticDataCollection
from .cim_data_generator import gen_cim_data
from .entities import CimSyntheticDataCollection, SyntheticPortSetting


def _dump_csv_file(file_path: str, headers: List[str], line_generator: callable):
    """Write a header row followed by generated data rows to a CSV file.

    Args:
        file_path(str): Path of the CSV file to write.
        headers(List[str]): Column names, written as the first row.
        line_generator(callable): Zero-argument callable returning an iterable of rows;
            rows are written in the order they are produced.
    """
    # newline="" lets the csv writer control line endings itself.
    with open(file_path, "wt+", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)

        # Header first, then stream every generated row straight to the writer.
        csv_writer.writerow(headers)
        csv_writer.writerows(line_generator())


class CimDataDumpUtil:
Expand Down Expand Up @@ -70,7 +87,7 @@ def stop_generator():
stop.leave_tick
]

self._dump_csv_file(stops_file_path, headers, stop_generator)
_dump_csv_file(stops_file_path, headers, stop_generator)

def _dump_ports(self, output_folder: str):
"""
Expand All @@ -86,6 +103,7 @@ def _dump_ports(self, output_folder: str):

def port_generator():
for port in self._data_collection.port_settings:
assert isinstance(port, SyntheticPortSetting)
yield [
port.index,
port.name,
Expand All @@ -99,7 +117,7 @@ def port_generator():
port.full_return_buffer.noise
]

self._dump_csv_file(ports_file_path, headers, port_generator)
_dump_csv_file(ports_file_path, headers, port_generator)

def _dump_vessels(self, output_folder: str):
"""
Expand Down Expand Up @@ -137,7 +155,7 @@ def vessel_generator():
vessel.empty
]

self._dump_csv_file(vessels_file_path, headers, vessel_generator)
_dump_csv_file(vessels_file_path, headers, vessel_generator)

def _dump_routes(self, output_folder: str, route_idx2name_dict: dict):
"""
Expand All @@ -161,7 +179,7 @@ def route_generator():
point.distance_to_next_port
]

self._dump_csv_file(routes_file_path, headers, route_generator)
_dump_csv_file(routes_file_path, headers, route_generator)

def _dump_order_proportions(self, output_folder: str, port_idx2name_dict: dict):
"""
Expand All @@ -179,6 +197,7 @@ def _dump_order_proportions(self, output_folder: str, port_idx2name_dict: dict):

def order_prop_generator():
for port in ports:
assert isinstance(port, SyntheticPortSetting)
for prop in port.target_proportions:
yield [
port.name,
Expand All @@ -189,7 +208,7 @@ def order_prop_generator():
prop.noise
]

self._dump_csv_file(proportion_file_path, headers, order_prop_generator)
_dump_csv_file(proportion_file_path, headers, order_prop_generator)

def _dump_misc(self, output_folder: str):
"""
Expand All @@ -213,22 +232,6 @@ def _dump_misc(self, output_folder: str):
with open(misc_file_path, "wt+") as fp:
safe_dump(misc_items, fp)

def _dump_csv_file(self, file_path: str, headers: List[str], line_generator: callable):
"""helper method to dump csv file
Args:
file_path(str): path of output csv file
headers(List[str]): list of header
line_generator(callable): generator function to generate line to write
"""
with open(file_path, "wt+", newline="") as fp:
writer = csv.writer(fp)

writer.writerow(headers)

for line in line_generator():
writer.writerow(line)


def dump_from_config(config_file: str, output_folder: str, max_tick: int):
"""Dump CIM data from config: this calls the data generator to generate the data, then dumps it.
Expand All @@ -245,9 +248,7 @@ def dump_from_config(config_file: str, output_folder: str, max_tick: int):
assert output_folder is not None and os.path.exists(output_folder), f"Got output folder path: {output_folder}"
assert max_tick is not None and max_tick > 0, f"Got max tick: {max_tick}"

generator = CimDataGenerator()

data_collection = generator.gen_data(config_file, max_tick=max_tick, start_tick=0, topology_seed=None)
data_collection = gen_cim_data(config_file, max_tick=max_tick, start_tick=0, topology_seed=None)

dump_util = CimDataDumpUtil(data_collection)

Expand Down

0 comments on commit 56fcfa2

Please sign in to comment.