# Adding a new scenario

We will try to implement a simple logistic scenario that randomly move stocks between warehouses, and record their shortage as state.


## Data model definition

In [1]:
# define a warehouse node as our data model, with 2 attributes

from  maro.backends.frame import node, NodeAttribute, NodeBase

# our node name in frames
@node("warehouses")
class Warehouse(NodeBase):
    # attribute defined with NodeAttribute will be stored in frame
    stocks = NodeAttribute("i")

    # shortage when need stocks     
    shortage = NodeAttribute("i")  

    def __init__(self):
        
        # init stocks used to reset initialized value after reset,       
        self._init_stocks: int = 0

    def set_init_state(self, stocks:int):
        self._init_stocks = stocks # this attribute will not in frame
        self.stocks = stocks

    def reset(self):
        # since frame will clear the stocks value to 0 after reset, we need to set it again        
        self.stocks = self._init_stocks

    def _on_shortage_changed(self, node_index: int, value: int):
        # do something here on shortage changed
        pass


In [1]:
    from maro.backends.frame import node, NodeBase, NodeAttribute

    @node("mynodes") #specified node name in Frame and SnapshotList
    class MyNode(NodeBase):
        """
        Customized node should inherit from NodeBase, and use node decorator to specified node name.

        After Frame initialazing, declared attributes will be add to MyNode instances.
        """

        # attribute name in Node, with a data type "i" (integer), and default slot is 1
        my_int_attribute = NodeAttribute("i")

        # a length 2 float arrays
        my_float_array = NodeAttribute("f", 2) 


        def _on_my_int_attribute_changed(self, node_index:int, value: int):
            pass # do something here


    # after you get the instance from frame, you can access attributes with different ways.
    node = frame.mynodes[0]

    # attributes with 1 slot, can be access as normal python fields
    node.my_int_attribute = 12

    print(node.my_int_attribute)

    # attributes with 2 slot, can be access like an array object
    node.my_float_array[0] = 12.1

    print(node.my_float_array[1])

    # set values
    node.my_float_array[:] = (1.0, 2.0)

    # get all values
    print(node.my_float_array[:])

    # with list of index
    node.my_float_array[(0, 1)] = (1.1, 1.3)

    print(node.my_float_array[(0,1)])

NameError: name 'frame' is not defined

In [2]:
# define a frame used to store states of warehouse

from maro.backends.frame import FrameNode, FrameBase


# we need to specified total in-memory snapshots number at definition time
def build_frame(warehouse_number: int, total_snapshot: int):
    # since we need to specified node number at definition time, so we can use a function to delay the definition
    # customized frame must inherit from FrameBase
    class LogisticsFrame(FrameBase):
        # add ware house node into current frame, and spcified number of warehouse
        # the class of Warehouse here is used to create wrapper to access node attribute
        # NOTE: nodes contain a built-in field "index" used as id in frame which is 0 based, we need this for further using, like states quering
        warehouses = FrameNode(Warehouse, warehouse_number)
        
        def __init__(self):
            super().__init__(enable_snapshot=True, total_snapshot=total_snapshot)

    return LogisticsFrame()
        

## Events definition

In [3]:
from enum import Enum

# customized new events with enum or other prefered way
class LogisticsEvents(Enum):
    # we have one event here that will be triggered if stock arrived at destination warehouse
    STOCK_ARRIVED = "stock_arrive" # stocks arrived at destinition warehouse
    STOCK_TRANSFER = "stock_transfer" # need to move stocks from soure to destinition warehouse


## Business engine

In [4]:
# the implementation of business engine

import random
from maro.simulator.scenarios import AbsBusinessEngine
from maro.event_buffer import Event, DECISION_EVENT

class LogisticsBusinessEngine(AbsBusinessEngine):
    def __init__(self, **kwargs): # this is a short way, as there are lots of parameters here
        # first parameter is the scenario name
        super().__init__("logistics", **kwargs)

        # udpate configuration path with current file path if use built-in topology logic
        self.update_config_root_path(os.path.abspath(''))
        # or self.update_config_root_path(__file__) if not in notebook

        # then we can use self._config_path to combine the configuration file
        # we do not need it here

        # initialize the frame
        self._init_frame()

        # register events
        self._register_events()

        # initialize nodes
        self._init_nodes()

    @property
    def frame(self):
        return self._frame

    @property
    def snapshots(self):
        return self._frame.snapshots

    def step(self, tick: int):
        # we generate a stock transfer even randomly here for demo
        # you may need to read data from binary file, refer to raw_data.ipynb for more detailss
        transfer_stock_num: int = random.randint(0, 6)

        if random.random() > 0.5:
            # pick a warehouse to transfer
            source_warehouse_index = tick % len(self._warehouses)
            target_warehouse_index = (tick + 1) % len(self._warehouses)

            transfer_payload = (source_warehouse_index, target_warehouse_index, transfer_stock_num)

            event = self._event_buffer.gen_atom_event(tick, LogisticsEvents.STOCK_TRANSFER, transfer_payload)

            self._event_buffer.insert_event(event)

    def post_step(self, tick: int):
        # This is a normal way to check if we need to take snapshot
        if (tick+1) % self._snapshot_resolution == 0:
            # AbsBusinessEngine provide a default function to convert tick to frame index in snapshot list
            self._frame.take_snapshot(self.frame_index(tick))

            # clear the shortage, as we need it to be state per snapshot
            for wh in self._warehouses:
                wh.shortage = 0

        # exit at the end of tick
        return tick+1 == self._max_tick

    def rewards(self, actions):
        # calculate rewards here
        pass

    def reset(self):
        # clear frame and snapshot values
        self._frame.reset()
        self._frame.snapshots.reset()

        # init the node states again
        for wh in self._warehouses:
            wh.reset()

    def _init_frame(self):
        # AbsBusinessEngine provide a easy way to calculate max snapshot number, but you can use your own
        self._frame = build_frame(10, self.calc_max_snapshots())

        # keep the reference of MyNode list for later using
        self._warehouses = self._frame.warehouses

    def _init_nodes(self):
        # NOTE: after frame reset, these values will be reset to 0, so we need to set it again after reset
        for wh in self._warehouses:
            wh.set_init_state(wh.index * 10)

    def _register_events(self):
        # decision event, predefined in event buffer, used to recieve actions from env.step
        self._event_buffer.register_event_handler(DECISION_EVENT, self._on_action_received)

        # handle customized events
        self._event_buffer.register_event_handler(LogisticsEvents.STOCK_ARRIVED, self._on_stock_arrived)
        self._event_buffer.register_event_handler(LogisticsEvents.STOCK_TRANSFER, self._on_stock_need_transfer)

    def _on_action_received(self, evt: Event):
        action = None if evt is None else evt.payload

        # process action here

    def _on_stock_need_transfer(self, evt: Event):
        # we generated the event in step function, and process it here

        # extract payload as we passed in event
        source_warehouse_index, target_warehouse_index, transfer_stock_num = evt.payload

        source_warehouse = self._warehouses[source_warehouse_index]
        target_warehouse = self._warehouses[target_warehouse_index]

        # how many we can transfer to target
        max_stocks_to_transfer = min(transfer_stock_num, source_warehouse.stocks)

        # update shortage if there is any
        if max_stocks_to_transfer != transfer_stock_num:
            source_warehouse.shortage += transfer_stock_num - max_stocks_to_transfer

        # update stocks in source  
        source_warehouse.stocks -= max_stocks_to_transfer

        # insert a new event with a delay to simuate transfer time on the way
        transfer_time = 2

        payload = (target_warehouse_index, max_stocks_to_transfer)

        # generate a recieved event, it will be triggered 2 ticks later
        recieved_evt = self._event_buffer.gen_atom_event(evt.tick + transfer_time, LogisticsEvents.STOCK_ARRIVED, payload)

        self._event_buffer.insert_event(recieved_evt)
      

    def _on_stock_arrived(self, evt: Event):
        
        target_warehouse_index, stock_number = evt.payload

        target_warehouse = self._warehouses[target_warehouse_index]

        target_warehouse.stocks += stock_number



## Run in environment

In [5]:
from maro.simulator import Env

start_tick: int = 0
durations: int = 4

reward, decision_event = None, None

# to use a non built-in scenario, we need to pass the class of new business engine to run,
# the topology if the relative path to the specified configurtion folder, we do not have one yet
env = Env(business_engine_cls=LogisticsBusinessEngine, start_tick=0, durations=durations)

is_done = False
action = None # first action must be None for each episode, as environment use a generator internally

while not is_done:
    reward, decision_event, is_done = env.step(action)

    # choose your action
    # action = agent.choose_action()

# NOTE: we can retrieve states before env.reset

## States quering

In [6]:
# states quering with env.stapshot_list
    
# get snapshot_list for specified node by node name in frame (with node decorator)
warehouse_snapshots = env.snapshot_list["warehouses"]


In [7]:
# we can use len(env.snapshot_list["node name"]) to get node number 
warehouse_number = len(warehouse_snapshots)

warehouse_number

10

In [8]:
# quering returns 1-dim numpy array by default
stocks_all_ticks = warehouse_snapshots[::"stocks"]

stocks_all_ticks


array([ 0., 10., 20., 30., 40., 50., 60., 70., 80., 90.,  0., 10., 20.,
       30., 40., 50., 60., 70., 80., 90.,  0., 10., 19., 30., 40., 50.,
       60., 70., 80., 90.,  0., 10., 19., 30., 40., 50., 60., 70., 80.,
       90.], dtype=float32)

In [9]:
# reshape it for easy reading
# we can always reshape states by (frame_index_number, node_number, attributes_number)
stocks_all_ticks.reshape(-1, len(warehouse_snapshots), 1)

array([[[ 0.],
        [10.],
        [20.],
        [30.],
        [40.],
        [50.],
        [60.],
        [70.],
        [80.],
        [90.]],

       [[ 0.],
        [10.],
        [20.],
        [30.],
        [40.],
        [50.],
        [60.],
        [70.],
        [80.],
        [90.]],

       [[ 0.],
        [10.],
        [19.],
        [30.],
        [40.],
        [50.],
        [60.],
        [70.],
        [80.],
        [90.]],

       [[ 0.],
        [10.],
        [19.],
        [30.],
        [40.],
        [50.],
        [60.],
        [70.],
        [80.],
        [90.]]], dtype=float32)

In [10]:
#quering with multiple attributes
multiple_states_all_ticks = warehouse_snapshots[::["stocks", "shortage"]]

multiple_states_all_ticks = multiple_states_all_ticks.reshape(-1, warehouse_number, 2)

multiple_states_all_ticks

array([[[ 0.,  2.],
        [10.,  0.],
        [20.,  0.],
        [30.,  0.],
        [40.,  0.],
        [50.,  0.],
        [60.,  0.],
        [70.,  0.],
        [80.,  0.],
        [90.,  0.]],

       [[ 0.,  0.],
        [10.,  0.],
        [20.,  0.],
        [30.,  0.],
        [40.,  0.],
        [50.,  0.],
        [60.,  0.],
        [70.,  0.],
        [80.,  0.],
        [90.,  0.]],

       [[ 0.,  0.],
        [10.,  0.],
        [19.,  0.],
        [30.,  0.],
        [40.,  0.],
        [50.,  0.],
        [60.,  0.],
        [70.,  0.],
        [80.,  0.],
        [90.,  0.]],

       [[ 0.,  0.],
        [10.,  0.],
        [19.,  0.],
        [30.,  0.],
        [40.,  0.],
        [50.,  0.],
        [60.,  0.],
        [70.,  0.],
        [80.,  0.],
        [90.,  0.]]], dtype=float32)

In [11]:
# then 1 column is stocks
stocks_all_ticks_2 = multiple_states_all_ticks[:, :, 0]

stocks_all_ticks_2

array([[ 0., 10., 20., 30., 40., 50., 60., 70., 80., 90.],
       [ 0., 10., 20., 30., 40., 50., 60., 70., 80., 90.],
       [ 0., 10., 19., 30., 40., 50., 60., 70., 80., 90.],
       [ 0., 10., 19., 30., 40., 50., 60., 70., 80., 90.]], dtype=float32)

In [12]:
# 2nd is shortage

shortage_all_ticks = multiple_states_all_ticks[:, :, 1]

shortage_all_ticks

array([[2., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [13]:
# quering states for one node
first_warehouse_stocks_all_ticks = warehouse_snapshots[:0:"stocks"]

first_warehouse_stocks_all_ticks

array([0., 0., 0., 0.], dtype=float32)

In [14]:
# quering states for multiple nodes
first_2_warehouse_stocks_all_ticks = warehouse_snapshots[:(0,1):"stocks"]

# after reshape, one column is for one node
first_2_warehouse_stocks_all_ticks.reshape(-1, 2)

array([[ 0., 10.],
       [ 0., 10.],
       [ 0., 10.],
       [ 0., 10.]], dtype=float32)

In [18]:
# quering for 1 frame index (same as tick for this case)
seconds_warehouse_stocks_1st_tick = warehouse_snapshots[0:1:"stocks"]

# NOTE: this value may not the initial value, as snapshot is took at the end of tick or when need a decision
seconds_warehouse_stocks_1st_tick

array([10.], dtype=float32)

In [16]:
# quering with multiple frame index (same as tick for this case)
seconds_warehouse_stocks_first_2_ticks = warehouse_snapshots[[0,1,]:1:"stocks"]

seconds_warehouse_stocks_first_2_ticks.reshape(2,1)

array([[10.],
       [10.]], dtype=float32)