# Quick Start

This notebook demonstrates the use of MARO's RL toolkit to optimize container inventory management. The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to load or discharge. The objective is to minimize the overall container shortage over a certain period of time.

In [None]:
# Environment configuration plus the state/action/reward shaping settings shared by the cells below.
env_conf = dict(
    scenario="cim",
    topology="toy.4p_ssdd_l0.0",
    durations=560,
)

# Snapshot attributes queried for ports and vessels.
port_attributes = [
    "empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment",
]
vessel_attributes = ["empty", "full", "remaining_space"]

state_shaping_conf = dict(
    look_back=7,
    max_ports_downstream=2,
)

action_shaping_conf = dict(
    # 21 evenly spaced values from -1.0 to 1.0 in steps of 0.1; index 10 holds 0.0.
    action_space=[(step - 10) / 10 for step in range(21)],
    finite_vessel_space=True,
    has_early_discharge=True,
)

reward_shaping_conf = dict(
    time_window=99,
    fulfillment_factor=1.0,
    shortage_factor=1.0,
    time_decay=0.97,
)

## Environment Sampler

An environment sampler defines state, action and reward shaping logic so that policies can interact with the environment.

In [None]:
import numpy as np
from maro.rl.learning import AbsEnvSampler
from maro.simulator.scenarios.cim.common import Action, ActionType


class CIMEnvSampler(AbsEnvSampler):
    """Environment sampler for the CIM scenario.

    Implements the state, action and reward shaping hooks of ``AbsEnvSampler`` using ``state_shaping_conf``,
    ``action_shaping_conf`` and ``reward_shaping_conf`` from the configuration cell.
    """

    def get_state(self, tick=None):
        """Build the state vector for the port at which the current decision event occurs.

        The vector is the concatenation of:
          * the values of every attribute in ``port_attributes`` for the current port and for every port in the
            vessel's future stop list, taken at ``look_back - 1`` ticks ending at ``tick`` (clamped at 0), and
          * the values of every attribute in ``vessel_attributes`` for the current vessel at ``tick``.

        NOTE(review): this uses ``look_back - 1`` ticks and the vessel's full future stop list, whereas
        ``state_dim`` (policy cell) assumes ``look_back + 1`` ticks and ``max_ports_downstream + 1`` ports.
        With the default configuration the two lengths only agree if the future stop list has exactly 3
        entries (6 * 4 * 7 == 8 * 3 * 7) -- TODO confirm against the topology before changing either side.

        Args:
            tick: Tick to build the state for. Defaults to the environment's current tick.

        Returns:
            A single-entry dict mapping the current port index to its state vector.
        """
        if tick is None:
            tick = self.env.tick
        vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
        port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
        # tick, tick - 1, ..., clamped at 0 so that early decision events never query negative ticks.
        ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
        future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
        state = np.concatenate([
            port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
            vessel_snapshots[tick : vessel_idx : vessel_attributes]
        ])
        return {port_idx: state}

    def get_env_actions(self, action_by_agent):
        """Translate the policy output for the current decision event into a CIM ``Action``.

        The policy output is an index into ``action_space`` in ``action_shaping_conf`` (0-20 for the default
        21-element space).
          * A negative value loads that fraction of ``action_scope.load`` from the port onto the vessel, capped
            by the vessel's remaining space when ``finite_vessel_space`` is set.
          * A positive value discharges that fraction of ``action_scope.discharge`` from the vessel to the port,
            adjusted for early discharge when ``has_early_discharge`` is set.
          * The middle index, whose value is 0.0, means doing nothing.
        For example, action 5 corresponds to -0.5 (load 50%), action 18 corresponds to 0.8 (discharge 80%) and
        action 10 corresponds to 0.0 (no-op).

        Args:
            action_by_agent: Dict mapping the acting port index to its policy output. The output is either the
                action index itself or a dict holding it under the "action" key. Only one entry is used.

        Returns:
            A single-element list containing the CIM ``Action`` to execute.
        """
        action_space = action_shaping_conf["action_space"]
        finite_vsl_space = action_shaping_conf["finite_vessel_space"]
        has_early_discharge = action_shaping_conf["has_early_discharge"]

        port_idx, action = list(action_by_agent.items()).pop()
        vsl_idx, action_scope = self.event.vessel_idx, self.event.action_scope
        vsl_snapshots = self.env.snapshot_list["vessels"]
        if finite_vsl_space:
            # Loading is capped by the vessel's remaining space at the current tick.
            vsl_attrs = vsl_snapshots[self.env.tick:vsl_idx:vessel_attributes]
            vsl_space = vsl_attrs[vessel_attributes.index("remaining_space")]
        else:
            vsl_space = float("inf")

        model_action = action["action"] if isinstance(action, dict) else action
        percent = abs(action_space[model_action])
        # Integer division gives the index of the 0.0 entry in an odd-length symmetric space. True division
        # (21 / 2 == 10.5) would make the no-op branch below unreachable.
        zero_action_idx = len(action_space) // 2
        if model_action < zero_action_idx:
            action_type = ActionType.LOAD
            actual_action = min(round(percent * action_scope.load), vsl_space)
        elif model_action > zero_action_idx:
            action_type = ActionType.DISCHARGE
            early_discharge = vsl_snapshots[self.env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
            # Chosen so that plan_action + early_discharge == percent * (discharge + early_discharge); presumably
            # this compensates for containers already discharged early.
            plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
            actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
        else:
            actual_action, action_type = 0, ActionType.LOAD

        return [Action(port_idx=port_idx, vessel_idx=vsl_idx, quantity=actual_action, action_type=action_type)]

    def get_reward(self, actions, tick):
        """Compute one reward per port that acted at ``tick``.

        The reward is a linear combination of fulfillment and shortage measures. Each measure is the sum of the
        port's fulfillment (resp. shortage) values over the ``time_window`` ticks following ``tick``. Each tick
        is weighted by ``time_decay ** i`` (i = 0 for the first tick), which puts more emphasis on the near
        future. The coefficients are ``fulfillment_factor`` (added) and ``shortage_factor`` (subtracted), all
        taken from ``reward_shaping_conf``.

        Args:
            actions: The CIM ``Action`` objects executed at ``tick``.
            tick: The tick at which the actions were taken.

        Returns:
            Dict mapping each acting port index to its ``np.float32`` reward.
        """
        start_tick = tick + 1
        ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))

        # Get the ports that took actions at the given tick
        ports = [action.port_idx for action in actions]
        port_snapshots = self.env.snapshot_list["ports"]
        # Rows are ticks, columns are ports.
        future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
        future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)

        decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
        rewards = np.float32(
            reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
            - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
        )
        return {agent_id: reward for agent_id, reward in zip(ports, rewards)}

    def post_step(self, state, action, env_action, reward, tick):
        """Record the latest environment metrics in ``self.tracker``.

        ``tracker`` is a dict inherited from ``AbsEnvSampler`` that can hold any information worth keeping
        during a roll-out episode. Only the most recent ``env.metrics`` is kept (no history), under the
        "env_metric" key, for logging purposes.
        """
        self.tracker["env_metric"] = self.env.metrics

## [Policies](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#policy)

Each port (agent) is controlled by its own instance of the out-of-the-box ActorCritic policy.

In [None]:
import torch
from torch.optim import Adam, RMSprop

from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl.modeling import DiscreteACNet, FullyConnected
from maro.rl.policy import ActorCritic

# Length of the state vector the networks expect: (look_back + 1) ticks for each of (max_ports_downstream + 1)
# ports (the current port plus its downstream ports), with len(port_attributes) values per port and tick,
# followed by the vessel attributes. With the default configuration this is 8 * 3 * 7 + 3 = 171.
# NOTE(review): CIMEnvSampler.get_state builds its vector from look_back - 1 ticks over the vessel's full
# future stop list; the totals only coincide if that list has 3 entries -- TODO confirm for the topology.
state_dim = (
    (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
    + len(vessel_attributes)
)

# Actor head: one output per entry of the action space, passed through softmax.
actor_net_conf = dict(
    input_dim=state_dim,
    hidden_dims=[256, 128, 64],
    output_dim=len(action_shaping_conf["action_space"]),
    activation=torch.nn.Tanh,
    softmax=True,
    batch_norm=False,
    head=True,
)

# Critic head: a single scalar output per state.
critic_net_conf = dict(
    input_dim=state_dim,
    hidden_dims=[256, 128, 64],
    output_dim=1,
    activation=torch.nn.LeakyReLU,
    softmax=False,
    batch_norm=True,
    head=True,
)

# (optimizer class, optimizer keyword arguments) pairs used to build the actor and critic optimizers.
actor_optim_conf = (Adam, dict(lr=0.001))
critic_optim_conf = (RMSprop, dict(lr=0.001))

# Keyword arguments forwarded to every ActorCritic policy.
ac_conf = dict(
    reward_discount=0.0,
    grad_iters=10,
    critic_loss_cls=torch.nn.SmoothL1Loss,
    min_logp=None,
    critic_loss_coeff=0.1,
    entropy_coeff=0.01,
    # For PPO-style clipping, a "clip_ratio" entry (e.g. 0.8) can be added here.
    lam=0.0,
    get_loss_on_rollout=False,
)


class MyACNet(DiscreteACNet):
    """Actor-critic network with separate fully-connected actor and critic sub-networks.

    The sub-networks are built from ``actor_net_conf`` / ``critic_net_conf``. Each one is updated by its own
    optimizer, built from ``actor_optim_conf`` / ``critic_optim_conf``.
    """

    def __init__(self):
        super().__init__()
        self.actor = FullyConnected(**actor_net_conf)
        self.critic = FullyConnected(**critic_net_conf)
        # The optim confs are (optimizer class, keyword arguments) pairs; each optimizer only sees its own
        # sub-network's parameters.
        self.actor_optim = actor_optim_conf[0](self.actor.parameters(), **actor_optim_conf[1])
        self.critic_optim = critic_optim_conf[0](self.critic.parameters(), **critic_optim_conf[1])

    @property
    def input_dim(self):
        """Length of the state vector fed to both sub-networks."""
        return state_dim

    @property
    def num_actions(self):
        """Number of discrete actions, i.e. the size of the actor's output layer."""
        return actor_net_conf["output_dim"]

    def forward(self, states, actor: bool = True, critic: bool = True):
        """Run the requested sub-networks on ``states``.

        Returns:
            An ``(actor_output, critic_output)`` tuple. An entry is ``None`` when the corresponding flag is False.
        """
        return (self.actor(states) if actor else None), (self.critic(states) if critic else None)

    def step(self, loss):
        """Back-propagate ``loss`` and apply one update step with both optimizers."""
        self.actor_optim.zero_grad()
        self.critic_optim.zero_grad()
        loss.backward()
        self.actor_optim.step()
        self.critic_optim.step()

    def get_gradients(self, loss):
        """Back-propagate ``loss`` and return the resulting gradients without updating any parameters.

        Returns:
            Dict mapping each parameter name (from ``named_parameters``) to its ``.grad``.
        """
        self.actor_optim.zero_grad()
        self.critic_optim.zero_grad()
        loss.backward()
        return {name: param.grad for name, param in self.named_parameters()}

    def apply_gradients(self, grad):
        """Install externally computed gradients and apply one update step with both optimizers.

        Args:
            grad: Dict keyed by parameter name, in the form returned by ``get_gradients``.
        """
        for name, param in self.named_parameters():
            param.grad = grad[name]

        self.actor_optim.step()
        self.critic_optim.step()


# Policy factories keyed "ac.<index>". Each factory receives the policy name and builds a fresh MyACNet, so
# policies never share parameters. range(4) must cover every agent index the environment reports
# (presumably the 4 ports of the "toy.4p_ssdd_l0.0" topology -- confirm if the topology changes).
policy_func_dict = {
    f"ac.{port}": (lambda name: ActorCritic(name, MyACNet(), **ac_conf))
    for port in range(4)
}

## Learning Loop

This code cell demonstrates a typical single-threaded training workflow.

In [None]:
from maro.simulator import Env
from maro.rl.learning import learn
from maro.utils import set_seeds

def get_env_sampler():
    """Create the CIM environment sampler used by ``learn``.

    Each agent index is mapped to the policy named ``ac.<agent index>`` in ``policy_func_dict``. Reward
    evaluation is delayed by ``time_window`` ticks, presumably so that every future tick ``get_reward`` reads
    has already been simulated.
    """
    # A throwaway Env instance is built only to read the list of agent indices.
    agent_ids = Env(**env_conf).agent_idx_list
    return CIMEnvSampler(
        get_env=lambda: Env(**env_conf),
        get_policy_func_dict=policy_func_dict,
        agent2policy={agent_id: f"ac.{agent_id}" for agent_id in agent_ids},
        reward_eval_delay=reward_shaping_conf["time_window"],
    )

# Callback run after each roll-out (training) episode or episode segment.
def post_collect(trackers, ep, segment):
    """Print the env metric recorded by the first roll-out instance.

    ``trackers`` is a list because, in a distributed setting, the main thread may receive one tracker per
    roll-out instance; only the first one is reported here.
    """
    latest_metric = trackers[0]["env_metric"]
    print(f"env summary (episode {ep}, segment {segment}): {latest_metric}")

# Callback run after each evaluation episode.
def post_evaluate(trackers, ep):
    """Print the env metric recorded by the first roll-out instance during evaluation.

    ``trackers`` is a list because, in a distributed setting, the main thread may receive one tracker per
    roll-out instance; only the first one is reported here.
    """
    latest_metric = trackers[0]["env_metric"]
    print(f"env summary (evaluation episode {ep}): {latest_metric}")


# Seed the random number generators up front so that runs are reproducible.
set_seeds(1024)

# Train for 50 episodes, then run an evaluation episode after the last one, reporting metrics via the
# callbacks defined above.
learn(
    get_env_sampler,
    num_episodes=50,
    eval_after_last_episode=True,
    post_collect=post_collect,
    post_evaluate=post_evaluate,
)