# Quick Start

This notebook demonstrates how to use MARO's reinforcement learning (RL) toolkit to solve the container inventory management ([CIM](https://maro.readthedocs.io/en/latest/scenarios/container_inventory_management.html)) problem. The problem is formalized as a multi-agent reinforcement learning problem in which each port acts as a decision agent. When a vessel arrives at a port, that port's agent takes an action by transferring a certain number of containers to or from the vessel. The objective is for the agents to learn policies that minimize the cumulative container shortage.

In [None]:
import numpy as np

# Settings shared by the environment wrapper (state / action / reward shaping) and by the
# policy input dimension computed in the policy section.
common_config = dict(
    # Snapshot attributes that make up the state.
    port_attributes=["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"],
    vessel_attributes=["empty", "full", "remaining_space"],
    # State shaping: history length and number of downstream ports observed.
    look_back=7,
    max_ports_downstream=2,
    # Action shaping: size of the discrete action space and capacity / early-discharge handling.
    num_actions=21,
    finite_vessel_space=True,
    has_early_discharge=True,
    # Reward shaping: evaluation window, term weights and per-tick decay.
    reward_eval_delay=99,
    fulfillment_factor=1.0,
    shortage_factor=1.0,
    time_decay=0.97,
)

## State, Action and Reward Shaping

In [None]:
import numpy as np
from maro.rl import AbsEnvWrapper
from maro.simulator.scenarios.cim.common import Action, ActionType


class CIMEnvWrapper(AbsEnvWrapper):
    """Wrap a MARO CIM ``Env`` with state, action and reward shaping for multi-agent RL.

    Each port is an agent. On every decision event:

    - ``get_state`` builds the state of the port at which the vessel has arrived;
    - ``to_env_action`` translates that agent's discrete action index into a CIM ``Action``;
    - ``get_reward`` evaluates a time-decayed fulfillment / shortage reward over the
      ``reward_eval_delay`` ticks that follow the action.

    Args:
        env: The CIM ``Env`` to wrap.
        save_replay: Forwarded to ``AbsEnvWrapper``.
        replay_agent_ids: Forwarded to ``AbsEnvWrapper``.
        port_attributes (list): Port snapshot attributes included in the state.
        vessel_attributes (list): Vessel snapshot attributes included in the state.
        num_actions (int): Number of discrete actions, spread evenly over [-1.0, 1.0]. Indices below the
            middle load, indices above it discharge, and the middle index is a no-op.
        look_back (int): Number of past ticks observed in addition to the current one.
        max_ports_downstream (int): Number of the vessel's future stops observed in addition to the
            current port.
        reward_eval_delay (int): Number of ticks after an action over which its reward is evaluated.
        fulfillment_factor (float): Weight of future fulfillment in the reward.
        shortage_factor (float): Weight of future shortage in the reward.
        time_decay (float): Per-tick decay factor applied to future fulfillment and shortage.
        finite_vessel_space (bool): If True, load quantities are capped by the vessel's remaining space.
        has_early_discharge (bool): If True, the vessel's early discharge is subtracted when planning
            discharges.
    """

    def __init__(
        self, env, save_replay=True, replay_agent_ids=None, *, port_attributes, vessel_attributes, num_actions,
        look_back, max_ports_downstream, reward_eval_delay, fulfillment_factor, shortage_factor, time_decay,
        finite_vessel_space=True, has_early_discharge=True
    ):
        super().__init__(env, save_replay=save_replay, replay_agent_ids=replay_agent_ids, reward_eval_delay=reward_eval_delay)
        self.port_attributes = port_attributes
        self.vessel_attributes = vessel_attributes
        self.action_space = list(np.linspace(-1.0, 1.0, num_actions))
        self.look_back = look_back
        self.max_ports_downstream = max_ports_downstream
        self.fulfillment_factor = fulfillment_factor
        self.shortage_factor = shortage_factor
        self.time_decay = time_decay
        self.finite_vessel_space = finite_vessel_space
        self.has_early_discharge = has_early_discharge
        # Tick of the most recent decision event; used as the default tick by ``get_reward``.
        self._last_action_tick = None

    def get_state(self, tick=None):
        """Build the state for the port where the current decision event occurs.

        The state has two parts, concatenated:

        1. ``port_attributes`` for the decision port and the first ``max_ports_downstream`` future stops of
           the arriving vessel, taken over the current tick and the previous ``look_back`` ticks (clamped at
           tick 0).
        2. ``vessel_attributes`` of the arriving vessel at the current tick.

        Its length is therefore
        ``(look_back + 1) * (max_ports_downstream + 1) * len(port_attributes) + len(vessel_attributes)``.
        NOTE(review): this assumes the vessel has at least ``max_ports_downstream`` future stops.

        Also records the event context in ``self.state_info`` for ``to_env_action``.

        Args:
            tick (int): Tick at which to build the state. Defaults to the environment's current tick.

        Returns:
            dict: A single-entry dict mapping the decision port index to its state vector.
        """
        if tick is None:
            tick = self.env.tick
        vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
        port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
        # Current tick plus ``look_back`` past ticks (``look_back + 1`` in total); ticks before 0 clamp to 0.
        ticks = [max(0, tick - rt) for rt in range(self.look_back + 1)]
        # Only the first ``max_ports_downstream`` future stops are observed, so the state size depends on the
        # configuration rather than on the length of the vessel's route.
        future_port_idx_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
        downstream_port_idx_list = list(future_port_idx_list)[:self.max_ports_downstream]
        port_features = port_snapshots[ticks: [port_idx] + downstream_port_idx_list: self.port_attributes]
        vessel_features = vessel_snapshots[tick: vessel_idx: self.vessel_attributes]
        self.state_info = {
            "tick": tick, "action_scope": self.event.action_scope, "port_idx": port_idx, "vessel_idx": vessel_idx
        }
        state = np.concatenate((port_features, vessel_features))
        self._last_action_tick = tick
        return {port_idx: state}

    def to_env_action(self, action_by_agent):
        """Translate the acting agent's discrete action index into a CIM ``Action``.

        The magnitude of ``action_space[index]`` (in [0, 1]) is the fraction of the permitted amount to move:

        - indices below the middle of ``action_space`` load onto the vessel;
        - indices above the middle discharge from it;
        - the middle index moves nothing.

        Args:
            action_by_agent (dict): Maps the acting agent to its policy output. That output is either the
                action index or a tuple whose first element is the action index.

        Returns:
            Action: The CIM action for the vessel and port recorded by the last ``get_state`` call.
        """
        vessel_snapshots = self.env.snapshot_list["vessels"]
        action_info = list(action_by_agent.values())[0]
        model_action = action_info[0] if isinstance(action_info, tuple) else action_info
        tick, port, vessel = self.state_info["tick"], self.state_info["port_idx"], self.state_info["vessel_idx"]
        # Integer division: with an odd number of actions this is the index of the value 0.0, so the no-op
        # branch below is reachable (true division gives e.g. 21 / 2 == 10.5, which no index equals).
        zero_action_idx = len(self.action_space) // 2
        if self.finite_vessel_space:
            # Query by name instead of relying on the position of "remaining_space" in ``vessel_attributes``.
            vessel_space = vessel_snapshots[tick:vessel:"remaining_space"][0]
        else:
            vessel_space = float("inf")
        early_discharge = vessel_snapshots[tick:vessel:"early_discharge"][0] if self.has_early_discharge else 0
        percent = abs(self.action_space[model_action])

        action_scope = self.state_info["action_scope"]
        if model_action < zero_action_idx:
            action_type = ActionType.LOAD
            # ``vessel_space`` may be a numpy float, so cast to keep the quantity an integer.
            actual_action = int(min(round(percent * action_scope.load), vessel_space))
        elif model_action > zero_action_idx:
            action_type = ActionType.DISCHARGE
            # Subtract the early-discharge amount from the planned quantity. If nothing is left, fall back to
            # a plain fraction of the discharge scope.
            plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
            actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
        else:
            actual_action, action_type = 0, None

        return Action(vessel, port, actual_action, action_type)

    def get_reward(self, tick=None):
        """Delayed reward evaluation.

        Sums future fulfillment and future shortage over the ``reward_eval_delay`` ticks after ``tick``. Each
        tick is weighted by ``time_decay ** offset``. The reward is
        ``fulfillment_factor * fulfillment - shortage_factor * shortage``. Every agent recorded in
        ``action_history[tick]`` receives this same value.

        Args:
            tick (int): Tick of the action being evaluated. Defaults to the tick of the last decision event.

        Returns:
            dict: Maps each agent that acted at ``tick`` to its ``np.float32`` reward.
        """
        if tick is None:
            tick = self._last_action_tick
        port_snapshots = self.env.snapshot_list["ports"]
        start_tick = tick + 1
        ticks = list(range(start_tick, start_tick + self.reward_eval_delay))

        # The empty node slice is assumed to select all ports; the values are laid out tick by tick.
        future_fulfillment = port_snapshots[ticks::"fulfillment"]
        future_shortage = port_snapshots[ticks::"shortage"]
        # Repeat each tick's decay factor once per value queried within that tick.
        decay_list = [
            self.time_decay ** i for i in range(self.reward_eval_delay)
            for _ in range(future_fulfillment.shape[0] // self.reward_eval_delay)
        ]

        return {
            agent_id: np.float32(
                self.fulfillment_factor * np.dot(future_fulfillment, decay_list) -
                self.shortage_factor * np.dot(future_shortage, decay_list)
            )
            for agent_id in self.action_history[tick]
        }

## [Policy](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#agent)

We use MARO's out-of-the-box ActorCritic algorithm as the policy for each agent.

In [None]:
import torch

from maro.rl import ActorCritic, ActorCriticConfig, DiscreteACNet, ExperienceManager, FullyConnectedBlock, OptimOption

# Policy input size = per-port features over time and ports, plus the arriving vessel's features.
# With the default config: (7 + 1) ticks x (2 + 1) ports x 7 port attributes + 3 vessel attributes.
input_dim = (
    (common_config["look_back"] + 1)               # current tick + `look_back` past ticks
    * (common_config["max_ports_downstream"] + 1)  # port in question + downstream ports
    * len(common_config["port_attributes"])
    + len(common_config["vessel_attributes"])      # features of the arriving vessel
)

policy_config = {
    "model": {
        "network": {
            # Actor: maps a state to a probability distribution over the discrete actions.
            "actor": dict(
                input_dim=input_dim,
                output_dim=common_config["num_actions"],
                hidden_dims=[256, 128, 64],
                activation="tanh",
                softmax=True,
                batch_norm=False,
                head=True,
            ),
            # Critic: maps a state to a single value estimate.
            "critic": dict(
                input_dim=input_dim,
                output_dim=1,
                hidden_dims=[256, 128, 64],
                activation="leaky_relu",
                softmax=False,
                batch_norm=True,
                head=True,
            ),
        },
        "optimization": {
            "actor": OptimOption(optim_cls="adam", optim_params={"lr": 0.001}),
            "critic": OptimOption(optim_cls="rmsprop", optim_params={"lr": 0.001}),
        },
    },
    "experience_manager": {"capacity": 10000},
    "algorithm_config": {
        "reward_discount": 0.0,
        "train_epochs": 10,
        "gradient_iters": 1,
        # loss = actor_loss_coefficient * actor_loss + critic_loss
        "actor_loss_coefficient": 0.1,
        "critic_loss_cls": "smooth_l1",
    },
}


class MyACNet(DiscreteACNet):
    """Actor-critic network whose forward pass accepts raw NumPy states.

    The ``"actor"`` and ``"critic"`` sub-networks are read from ``self.component``.
    """

    def forward(self, states, actor: bool = True, critic: bool = True):
        """Run the actor and/or the critic on a single state or a batch of states.

        Args:
            states: A 1-D state vector, or a 2-D batch (an array or a sequence of equal-length vectors).
                It is converted to a ``float32`` tensor, torch's default parameter dtype. This lets
                float64 input (e.g. produced by NumPy arithmetic) pass through the layers without a
                dtype mismatch.
            actor (bool): Whether to compute the actor output.
            critic (bool): Whether to compute the critic output.

        Returns:
            tuple: ``(actor_output, critic_output)``; an entry is ``None`` when its flag is False.
        """
        states = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        if states.dim() == 1:
            # A single state: add a batch dimension.
            states = states.unsqueeze(dim=0)

        states = states.to(self.device)
        return (
            self.component["actor"](states) if actor else None,
            self.component["critic"](states) if critic else None
        )


def get_ac_policy(name):
    """Build an ``ActorCritic`` policy named ``name`` from the settings in ``policy_config``.

    The actor and critic are ``FullyConnectedBlock``s wrapped in ``MyACNet``. Each policy gets its own
    experience manager.
    """
    model_conf = policy_config["model"]
    network_conf = model_conf["network"]
    ac_net = MyACNet(
        {part: FullyConnectedBlock(**network_conf[part]) for part in ("actor", "critic")},
        optim_option=model_conf["optimization"],
    )
    return ActorCritic(
        name,
        ac_net,
        ExperienceManager(policy_config["experience_manager"]["capacity"]),
        ActorCriticConfig(**policy_config["algorithm_config"]),
    )

## Training

This code cell demonstrates a typical single-threaded training workflow.

In [None]:
from maro.simulator import Env
from maro.rl import SimpleLearner
from maro.utils import set_seeds

set_seeds(1024)  # for reproducibility
NUM_EPISODES = 40

env = Env("cim", "toy.4p_ssdd_l0.0", durations=1120)
env_wrapper = CIMEnvWrapper(env, **common_config)
# One policy per agent, named after the agent ID; each agent is then mapped to its own policy.
policies = [get_ac_policy(agent_id) for agent_id in env.agent_idx_list]
agent2policy = dict(zip(env.agent_idx_list, env.agent_idx_list))
learner = SimpleLearner(env_wrapper, policies, agent2policy, NUM_EPISODES)
learner.run()