# Quick Start

This notebook demonstrates the use of MARO's RL toolkit to optimize container inventory management. The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to load or discharge. The objective is to minimize the overall container shortage over a certain period of time.

In [5]:
# Environment and shaping configuration shared by every cell below.

# Environment: the CIM scenario on the 4-port toy topology.
env_conf = dict(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=560)

# Snapshot attributes queried when building the state vector.
port_attributes = [
    "empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment",
]
vessel_attributes = ["empty", "full", "remaining_space"]

# State shaping parameters (consumed by the env sampler and by the state-dimension computation).
state_shaping_conf = dict(look_back=7, max_ports_downstream=2)

# Action shaping: 21 evenly spaced fractions from -1.0 to 1.0; index 10 is 0.0.
action_shaping_conf = dict(
    action_space=[step / 10 for step in range(-10, 11)],
    finite_vessel_space=True,
    has_early_discharge=True,
)

# Reward shaping: decayed fulfillment/shortage sums over a fixed future window.
reward_shaping_conf = dict(
    time_window=99,
    fulfillment_factor=1.0,
    shortage_factor=1.0,
    time_decay=0.97,
)

## Environment Sampler

An environment sampler defines state, action and reward shaping logic so that policies can interact with the environment.

In [6]:
import numpy as np
from maro.rl.rollout import AbsEnvSampler
from maro.simulator.scenarios.cim.common import Action, ActionType


class CIMEnvSampler(AbsEnvSampler):
    """Environment sampler for the CIM scenario.

    Implements the state, action and reward shaping hooks of ``AbsEnvSampler``. It reads the module-level
    ``state_shaping_conf``, ``action_shaping_conf`` and ``reward_shaping_conf`` dicts and the
    ``port_attributes`` / ``vessel_attributes`` lists defined in the configuration cell.
    """

    def _get_global_and_agent_state(self, event, tick=None):
        """
        Build the state vector for the port that triggered ``event``.

        The state is the concatenation of:
          * every attribute in ``port_attributes``, for the current port and for each port in the vessel's
            ``future_stop_list``, over the most recent ``look_back - 1`` ticks (``look_back`` comes from
            ``state_shaping_conf``; ticks before 0 are clipped to 0);
          * every attribute in ``vessel_attributes`` for the arriving vessel at the current tick.

        The ``tick`` argument is ignored; the environment's current tick is always used.

        Returns:
            A ``(global_state, {port_idx: state})`` pair. The global state is the same vector as the acting
            port's state.
        """
        tick = self._env.tick
        vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
        port_idx, vessel_idx = event.port_idx, event.vessel_idx
        ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
        future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
        state = np.concatenate([
            port_snapshots[ticks: [port_idx] + list(future_port_list): port_attributes],
            vessel_snapshots[tick: vessel_idx: vessel_attributes]
        ])
        return state, {port_idx: state}

    def _translate_to_env_action(self, action_dict, event):
        """
        Translate the policy output into a CIM ``Action``.

        The policy output is an integer in [0, 20]. It is interpreted as an index into ``action_space`` in
        ``action_shaping_conf``. Examples:
          * Action 5 corresponds to -0.5: load 50% of the containers available at the current port onto
            the vessel.
          * Action 18 corresponds to 0.8: discharge 80% of the containers on the vessel to the port.
          * Action 10 corresponds to 0.0: do nothing. This is emitted as a zero-quantity load.
        Negative values load and positive values discharge, whatever the layout of ``action_space``.

        Returns:
            ``{port_idx: Action}`` for the acting port.
        """
        action_space = action_shaping_conf["action_space"]
        finite_vsl_space = action_shaping_conf["finite_vessel_space"]
        has_early_discharge = action_shaping_conf["has_early_discharge"]

        # The last (and expected only) entry of action_dict is the acting port and its policy output.
        port_idx, model_action = list(action_dict.items()).pop()
        # model_action is indexable; its first element is the chosen action index.
        action_idx = int(model_action[0])

        vsl_idx, action_scope = event.vessel_idx, event.action_scope
        vsl_snapshots = self._env.snapshot_list["vessels"]
        # Index 2 selects "remaining_space" from vessel_attributes.
        vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")

        action_value = action_space[action_idx]
        percent = abs(action_value)
        # Branch on the sign of the chosen value. The previous index midpoint, len(action_space) / 2 == 10.5,
        # could never equal an integer index, so the "do nothing" branch was unreachable.
        if action_value < 0:
            action_type = ActionType.LOAD
            actual_action = min(round(percent * action_scope.load), vsl_space)
        elif action_value > 0:
            action_type = ActionType.DISCHARGE
            early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
            plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
            actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
        else:
            # No-op: a zero-quantity load moves no containers. This matches what the 0.0 index produced before.
            actual_action, action_type = 0, ActionType.LOAD

        return {port_idx: Action(vsl_idx, int(port_idx), actual_action, action_type)}

    def _get_reward(self, env_action_dict, event, tick):
        """
        Compute the reward for each port that acted at ``tick``.

        The reward is a linear combination of fulfillment and shortage measures:
          * Each measure is the sum of the per-tick values over the ``time_window`` ticks following ``tick``
            (from ``reward_shaping_conf``).
          * The value at offset ``i`` is weighted by ``time_decay ** i``, which emphasises the near future.
          * The coefficients are ``fulfillment_factor`` (positive) and ``shortage_factor`` (negative).

        Returns:
            ``{port_idx: float32 reward}`` for each key of ``env_action_dict``.
        """
        time_window = reward_shaping_conf["time_window"]
        start_tick = tick + 1
        ticks = list(range(start_tick, start_tick + time_window))

        # Ports that took actions at the given tick.
        ports = [int(port) for port in env_action_dict]
        port_snapshots = self._env.snapshot_list["ports"]
        # One row per tick. This assumes the snapshot query returns values tick-major; TODO confirm.
        future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
        future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)

        decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(time_window)]
        rewards = np.float32(
            reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
            - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
        )
        return dict(zip(ports, rewards))

    def get_env_metrics(self) -> dict:
        """Return the environment's latest metrics (``self._env.metrics``) for logging.

        No history is kept; each call returns the current value.
        """
        return self._env.metrics

## [Policies](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#policy)

Each port's policy is trained with MARO's out-of-the-box discrete actor-critic algorithm (`DiscreteActorCritic`).

In [7]:
import torch
from torch.optim import Adam, RMSprop

from maro.rl.model import DiscretePolicyNet, FullyConnected, VNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import DiscreteActorCritic, DiscreteActorCriticParams

# State dimension expected by the actor and critic networks.
# The formula assumes the port in question plus `max_ports_downstream` downstream ports, each observed over
# `look_back + 1` ticks, followed by the vessel attributes.
# NOTE(review): CIMEnvSampler._get_global_and_agent_state actually builds the state from `look_back - 1` ticks
# for the current port plus every port in the vessel's `future_stop_list`. The two layouts have the same size
# only when (look_back - 1) * (1 + len(future_stop_list)) == (look_back + 1) * (max_ports_downstream + 1),
# e.g. 6 * 4 == 8 * 3 for a 3-stop future list. Confirm against the topology before changing either side.
state_dim = (
    (state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
    + len(vessel_attributes)
)
action_num = len(action_shaping_conf["action_space"])

# Actor-critic network settings.
actor_net_conf = {
    "input_dim": state_dim,
    "hidden_dims": [256, 128, 64],
    "output_dim": action_num,  # one output per discrete action
    "activation": torch.nn.Tanh,
    "softmax": True,
    "batch_norm": False,
    "head": True
}
critic_net_conf = {
    "input_dim": state_dim,
    "hidden_dims": [256, 128, 64],
    "output_dim": 1,  # a single state value
    "activation": torch.nn.LeakyReLU,
    "softmax": False,
    "batch_norm": True,
    "head": True
}

# (optimizer class, optimizer kwargs) pairs consumed by the network classes below.
actor_optim_conf = (Adam, {"lr": 0.001})
critic_optim_conf = (RMSprop, {"lr": 0.001})

# The critic factory is wrapped in a lambda because MyCriticNet is defined further down in this cell.
# The name is only resolved when the factory is called.
ac_conf = DiscreteActorCriticParams(
    get_v_critic_net_func=lambda: MyCriticNet(),
    reward_discount=.0,
    grad_iters=10,
    critic_loss_cls=torch.nn.SmoothL1Loss,
    min_logp=None,
    lam=.0
)


class MyActorNet(DiscretePolicyNet):
    """Actor network for a port's policy.

    Wraps a fully connected net built from ``actor_net_conf``. The net has its own optimizer, built from
    ``actor_optim_conf``.
    """

    def __init__(self) -> None:
        super().__init__(state_dim=state_dim, action_num=action_num)
        self._actor = FullyConnected(**actor_net_conf)
        optim_cls, optim_kwargs = actor_optim_conf
        self._actor_optim = optim_cls(self._actor.parameters(), **optim_kwargs)

    def _get_action_probs_impl(self, states: torch.Tensor) -> torch.Tensor:
        """Run ``states`` through the actor network and return its output."""
        action_probs = self._actor(states)
        return action_probs

    def get_gradients(self, loss: torch.Tensor):
        """Back-propagate ``loss`` from zeroed gradients and return ``{param_name: grad}``."""
        self._actor_optim.zero_grad()
        loss.backward()
        grads = {}
        for param_name, weight in self.named_parameters():
            grads[param_name] = weight.grad
        return grads

    def apply_gradients(self, grad: dict) -> None:
        """Install the gradients in ``grad`` (keyed by parameter name), then take one optimizer step."""
        for param_name, weight in self.named_parameters():
            weight.grad = grad[param_name]
        self._actor_optim.step()

    def get_net_state(self) -> dict:
        """Snapshot the network weights and the optimizer state."""
        net_state = {"network": self.state_dict()}
        net_state["actor_optim"] = self._actor_optim.state_dict()
        return net_state

    def set_net_state(self, net_state: dict) -> None:
        """Restore a snapshot produced by ``get_net_state``."""
        self.load_state_dict(net_state["network"])
        self._actor_optim.load_state_dict(net_state["actor_optim"])

    def freeze(self) -> None:
        """Freeze every parameter of this network."""
        self.freeze_all_parameters()

    def unfreeze(self) -> None:
        """Unfreeze every parameter of this network."""
        self.unfreeze_all_parameters()


class MyCriticNet(VNet):
    """State-value critic.

    Wraps a fully connected net built from ``critic_net_conf``. The net has its own optimizer, built from
    ``critic_optim_conf``.
    """

    def __init__(self) -> None:
        super().__init__(state_dim=state_dim)
        self._critic = FullyConnected(**critic_net_conf)
        optim_cls, optim_kwargs = critic_optim_conf
        self._critic_optim = optim_cls(self._critic.parameters(), **optim_kwargs)

    def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
        """Return the critic output with its trailing size-1 dimension squeezed away."""
        raw_values = self._critic(states)
        return raw_values.squeeze(-1)

    def freeze(self) -> None:
        """Freeze every parameter of this network."""
        self.freeze_all_parameters()

    def unfreeze(self) -> None:
        """Unfreeze every parameter of this network."""
        self.unfreeze_all_parameters()

    def get_gradients(self, loss: torch.Tensor):
        """Back-propagate ``loss`` from zeroed gradients and return ``{param_name: grad}``."""
        self._critic_optim.zero_grad()
        loss.backward()
        grads = {}
        for param_name, weight in self.named_parameters():
            grads[param_name] = weight.grad
        return grads

    def apply_gradients(self, grad: dict) -> None:
        """Install the gradients in ``grad`` (keyed by parameter name), then take one optimizer step."""
        for param_name, weight in self.named_parameters():
            weight.grad = grad[param_name]
        self._critic_optim.step()

    def get_net_state(self) -> dict:
        """Snapshot the network weights and the optimizer state."""
        net_state = {"network": self.state_dict()}
        net_state["critic_optim"] = self._critic_optim.state_dict()
        return net_state

    def set_net_state(self, net_state: dict) -> None:
        """Restore a snapshot produced by ``get_net_state``."""
        self.load_state_dict(net_state["network"])
        self._critic_optim.load_state_dict(net_state["critic_optim"])

# One policy instance and one trainer factory per port (indices 0-3).
policy_dict = dict(
    (f"ac_{idx}.policy", DiscretePolicyGradient(f"ac_{idx}.policy", policy_net=MyActorNet()))
    for idx in range(4)
)
# Each factory builds its trainer from the name it receives; the loop index is not captured.
trainer_creator = dict(
    (f"ac_{idx}", lambda name: DiscreteActorCritic(name, params=ac_conf))
    for idx in range(4)
)

## Learning Loop

This code cell demonstrates a typical single-threaded training workflow.

In [8]:
from maro.rl.rollout import SimpleAgentWrapper
from maro.rl.training import TrainerManager 
from maro.simulator import Env
from maro.utils import set_seeds

set_seeds(1024)  # for reproducibility

num_episodes = 50

# Map each agent (port) index to the policy trained for it.
agent2policy = {agent: f"ac_{agent}.policy" for agent in Env(**env_conf).agent_idx_list}

# The env sampler and trainer manager both take ``policy_creator`` as a parameter. The policy creator is
# a function that takes a name and returns a policy instance. This design is convenient in distributed
# settings, where:
#   * policies need to be created on both the training side and the roll-out (inference) side, and
#   * policy states need to be transferred from the former to the latter at the start of each roll-out
#     episode.
# Here we are demonstrating a single-threaded workflow where there is only one instance of each policy.
# Each creator therefore returns the existing instance from ``policy_dict``, so the env sampler and the
# trainer manager point to the same instances.
policy_creator = {policy_name: lambda name: policy_dict[name] for policy_name in policy_dict}

env_sampler = CIMEnvSampler(
    get_env=lambda: Env(**env_conf),
    policy_creator=policy_creator,
    agent2policy=agent2policy,
    agent_wrapper_cls=SimpleAgentWrapper,
    device="cpu"
)

trainer_manager = TrainerManager(policy_creator, trainer_creator, agent2policy)

# Main loop: an episode may be sampled in several segments, and training runs after every segment.
for ep in range(1, num_episodes + 1):
    segment, end_of_episode = 1, False
    while not end_of_episode:
        # experience collection
        result = env_sampler.sample()
        experiences = result["experiences"]
        end_of_episode = result["end_of_episode"]
        print(f"env summary (episode {ep}, segment {segment}): {env_sampler.get_env_metrics()}")
        trainer_manager.record_experiences(experiences)
        trainer_manager.train()
        segment += 1

Assign policy ac_0.policy to device cpu
Assign policy ac_1.policy to device cpu
Assign policy ac_2.policy to device cpu
Assign policy ac_3.policy to device cpu
Policy ac_0.policy has already been assigned to cpu. No need to take further actions.
Policy ac_1.policy has already been assigned to cpu. No need to take further actions.
Policy ac_2.policy has already been assigned to cpu. No need to take further actions.
Policy ac_3.policy has already been assigned to cpu. No need to take further actions.
env summary (episode 1, segment 1): {'order_requirements': 1120000, 'container_shortage': 762124, 'operation_number': 2035086}
env summary (episode 2, segment 1): {'order_requirements': 1120000, 'container_shortage': 523489, 'operation_number': 2166897}
env summary (episode 3, segment 1): {'order_requirements': 1120000, 'container_shortage': 430128, 'operation_number': 2180610}
env summary (episode 4, segment 1): {'order_requirements': 1120000, 'container_shortage': 254209, 'operation_number