# Quick Start

This notebook demonstrates how to use MARO's reinforcement learning (RL) toolkit to solve the container inventory management ([CIM](https://maro.readthedocs.io/en/latest/scenarios/container_inventory_management.html)) problem. It is formalized as a multi-agent reinforcement learning problem, where each port acts as a decision agent. When a vessel arrives at a port, these agents must take actions by transferring a certain number of containers to / from the vessel. The objective is for the agents to learn policies that minimize the cumulative container shortage.

In [None]:
# Environment settings, passed as keyword arguments to the simulator ``Env``.
env_conf = dict(scenario="cim", topology="toy.4p_ssdd_l0.0", durations=560)

# Snapshot attributes that are queried for ports and vessels.
port_attributes = [
    "empty",
    "full",
    "on_shipper",
    "on_consignee",
    "booking",
    "shortage",
    "fulfillment",
]
vessel_attributes = [
    "empty",
    "full",
    "remaining_space",
]

# State shaping settings.
state_shaping_conf = dict(look_back=7, max_ports_downstream=2)

# Action shaping settings: the action space holds 21 evenly spaced values
# from -1.0 to 1.0 in steps of 0.1.
action_shaping_conf = dict(
    action_space=[step / 10 for step in range(-10, 11)],
    finite_vessel_space=True,
    has_early_discharge=True,
)

# Reward shaping settings: size of the look-ahead window (in ticks), weights of
# the fulfillment and shortage terms, and the per-tick decay factor.
reward_shaping_conf = dict(
    time_window=99,
    fulfillment_factor=1.0,
    shortage_factor=1.0,
    time_decay=0.97,
)

## Environment Sampler

In [None]:
import numpy as np
from maro.rl.learning import AbsEnvSampler
from maro.simulator.scenarios.cim.common import Action, ActionType


class CIMEnvSampler(AbsEnvSampler):
    """Environment sampler that shapes states, actions and rewards for the CIM scenario.

    The shaping logic is driven by the module-level ``state_shaping_conf``,
    ``action_shaping_conf``, ``reward_shaping_conf``, ``port_attributes`` and
    ``vessel_attributes`` settings.
    """

    def get_state(self, tick=None):
        """Build the state for the port referenced by the current decision event.

        Args:
            tick: Tick at which to query the snapshots. Defaults to the
                environment's current tick.

        Returns:
            A single-entry dict mapping the deciding port's index to the
            concatenation of the queried port features and vessel features.
        """
        if tick is None:
            tick = self.env.tick
        vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
        port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
        # Look-back ticks, clamped at 0 so that early ticks never become negative.
        # NOTE(review): this yields look_back - 1 ticks, whereas state_dim elsewhere in
        # this file uses a (look_back + 1) factor -- confirm that the resulting state
        # length matches the networks' input_dim.
        ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
        future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
        state = np.concatenate([
            port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
            vessel_snapshots[tick : vessel_idx : vessel_attributes]
        ])
        return {port_idx: state}

    def get_env_actions(self, action_by_agent):
        """Convert a policy output into an executable CIM ``Action``.

        Args:
            action_by_agent: Dict mapping a port index to the chosen action. The
                action is either an index into the action space or a dict that
                holds that index under the ``"action"`` key.

        Returns:
            A one-element list with the ``Action`` to execute. Indices below the
            middle of the action space load containers, indices above it
            discharge containers, and the middle index results in a zero-quantity
            action with ``action_type`` set to ``None``.
        """
        action_space = action_shaping_conf["action_space"]
        finite_vsl_space = action_shaping_conf["finite_vessel_space"]
        has_early_discharge = action_shaping_conf["has_early_discharge"]

        port_idx, action = list(action_by_agent.items()).pop()
        vsl_idx, action_scope = self.event.vessel_idx, self.event.action_scope
        vsl_snapshots = self.env.snapshot_list["vessels"]
        if finite_vsl_space:
            # Look the attribute position up by name instead of hard-coding it.
            space_pos = vessel_attributes.index("remaining_space")
            vsl_space = vsl_snapshots[self.env.tick:vsl_idx:vessel_attributes][space_pos]
        else:
            vsl_space = float("inf")

        model_action = action["action"] if isinstance(action, dict) else action
        percent = abs(action_space[model_action])
        # Index corresponding to value zero. Integer division is required here: with
        # true division the result is fractional for an odd-sized action space
        # (e.g. 10.5 for 21 actions), so the no-op branch below could never be reached.
        zero_action_idx = len(action_space) // 2
        if model_action < zero_action_idx:
            action_type = ActionType.LOAD
            # Never load more than the vessel can still take.
            actual_action = min(round(percent * action_scope.load), vsl_space)
        elif model_action > zero_action_idx:
            action_type = ActionType.DISCHARGE
            early_discharge = vsl_snapshots[self.env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
            plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
            actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
        else:
            actual_action, action_type = 0, None

        return [Action(port_idx=port_idx, vessel_idx=vsl_idx, quantity=actual_action, action_type=action_type)]

    def get_reward(self, actions, tick):
        """Delayed reward evaluation.

        The reward of each acting port is the decayed sum of its fulfillment minus
        the decayed sum of its shortage over the ``time_window`` ticks that follow
        ``tick``, each term weighted by its configured factor.

        Args:
            actions: Actions that were taken at ``tick``.
            tick: Tick at which the actions were taken.

        Returns:
            Dict mapping each acting port's index to its reward.
        """
        start_tick = tick + 1
        ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))

        # Get the ports that took actions at the given tick
        ports = [action.port_idx for action in actions]
        port_snapshots = self.env.snapshot_list["ports"]
        # Reshape to (number of ticks, -1) so that the transpose can be dotted with the decay weights.
        future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
        future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)

        # Weight of the i-th future tick is time_decay ** i.
        decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
        rewards = np.float32(
            reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
            - reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
        )
        return {agent_id: reward for agent_id, reward in zip(ports, rewards)}

## [Policies](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#agent)

The out-of-the-box ActorCritic is used as our agent.

In [None]:
import torch
from torch.optim import Adam, RMSprop

from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl.modeling import DiscreteACNet, FullyConnected
from maro.rl.policy import ActorCritic

# The state covers the port in question as well as two downstream ports.
# Port attributes are counted for the past 7 days plus the current day (hence the factor 8),
# and the vessel attributes are appended once at the end.
state_dim = (
    len(port_attributes)
    * (state_shaping_conf["max_ports_downstream"] + 1)
    * (state_shaping_conf["look_back"] + 1)
    + len(vessel_attributes)
)

# Actor network: one output per entry of the action space, with softmax enabled.
actor_net_conf = dict(
    input_dim=state_dim,
    hidden_dims=[256, 128, 64],
    output_dim=len(action_shaping_conf["action_space"]),
    activation=torch.nn.Tanh,
    softmax=True,
    batch_norm=False,
    head=True,
)

# Critic network: a single scalar output.
critic_net_conf = dict(
    input_dim=state_dim,
    hidden_dims=[256, 128, 64],
    output_dim=1,
    activation=torch.nn.LeakyReLU,
    softmax=False,
    batch_norm=True,
    head=True,
)

# (optimizer class, optimizer keyword arguments) pairs.
actor_optim_conf = (Adam, dict(lr=0.001))
critic_optim_conf = (RMSprop, dict(lr=0.001))

# Keyword arguments forwarded to the ActorCritic policy.
ac_conf = dict(
    reward_discount=0.0,
    grad_iters=10,
    critic_loss_cls=torch.nn.SmoothL1Loss,
    min_logp=None,
    critic_loss_coeff=0.1,
    entropy_coeff=0.01,
    # "clip_ratio": 0.8   # for PPO
    lam=0.0,
    get_loss_on_rollout=False,
)


class MyACNet(DiscreteACNet):
    """Actor-critic network made of two fully connected networks with separate optimizers.

    The networks are built from the module-level ``actor_net_conf`` and
    ``critic_net_conf``; the optimizers from ``actor_optim_conf`` and
    ``critic_optim_conf``.
    """

    def __init__(self):
        super().__init__()
        self.actor = FullyConnected(**actor_net_conf)
        self.critic = FullyConnected(**critic_net_conf)
        # Each optimizer conf is an (optimizer class, keyword arguments) pair.
        self.actor_optim = actor_optim_conf[0](self.actor.parameters(), **actor_optim_conf[1])
        self.critic_optim = critic_optim_conf[0](self.critic.parameters(), **critic_optim_conf[1])

    @property
    def input_dim(self):
        """Dimension of the state vectors fed to both networks."""
        return state_dim

    @property
    def num_actions(self):
        """Number of discrete actions, i.e. the actor network's output dimension."""
        # The original read an undefined ``q_net_conf``; the actor config holds this value.
        return actor_net_conf["output_dim"]

    def forward(self, states, actor: bool = True, critic: bool = True):
        """Run the requested networks on ``states``.

        Args:
            states: Input passed to the actor and/or critic network.
            actor: Whether to evaluate the actor network.
            critic: Whether to evaluate the critic network.

        Returns:
            A ``(actor output, critic output)`` tuple; an entry is ``None`` when the
            corresponding network was not requested.
        """
        return (self.actor(states) if actor else None), (self.critic(states) if critic else None)

    def step(self, loss):
        """Back-propagate ``loss`` and take one step with both optimizers."""
        self.actor_optim.zero_grad()
        self.critic_optim.zero_grad()
        loss.backward()
        self.actor_optim.step()
        self.critic_optim.step()

    def get_gradients(self, loss):
        """Back-propagate ``loss`` and return the gradients without updating the parameters.

        Returns:
            Dict mapping each parameter name to its gradient.
        """
        self.actor_optim.zero_grad()
        self.critic_optim.zero_grad()
        loss.backward()
        return {name: param.grad for name, param in self.named_parameters()}

    def apply_gradients(self, grad):
        """Load the gradients in ``grad`` (keyed by parameter name) and step both optimizers."""
        for name, param in self.named_parameters():
            param.grad = grad[name]

        self.actor_optim.step()
        self.critic_optim.step()


# One policy creator per policy name ("ac.0" ... "ac.3"). Each creator takes the
# policy name and builds an ActorCritic policy backed by a fresh MyACNet.
policy_func_dict = {
    f"ac.{i}": (lambda name: ActorCritic(name, MyACNet(), **ac_conf))
    for i in range(4)
}

## Training

This code cell demonstrates a typical single-threaded training workflow.

In [None]:
from maro.simulator import Env
from maro.rl.learning import simple_learner
from maro.utils import set_seeds

# post-env-step callback
def post_step(env, tracker, state, action, env_action, reward, tick):
    """Store the environment's current metrics in the tracker under ``"env_metric"``.

    Only ``env`` and ``tracker`` are used; the remaining arguments are ignored.
    """
    current_metrics = env.metrics
    tracker["env_metric"] = current_metrics

def get_env_sampler():
    """Create a ``CIMEnvSampler`` wired to the policies and callbacks defined in this notebook."""

    def make_env():
        return Env(**env_conf)

    # A throwaway environment is created only to read the list of agent indices;
    # each agent is mapped to the policy named "ac.<agent>".
    agent2policy = {agent: f"ac.{agent}" for agent in make_env().agent_idx_list}
    return CIMEnvSampler(
        get_env=make_env,
        get_policy_func_dict=policy_func_dict,
        agent2policy=agent2policy,
        reward_eval_delay=reward_shaping_conf["time_window"],
        post_step=post_step,
    )

# post-episode callback
def post_collect(trackers, ep, segment):
    """Print each tracker's env metric and, when there are several trackers, their average.

    Args:
        trackers: Sequence of dicts, each holding a metric dict under ``"env_metric"``.
        ep: Episode number shown in the output.
        segment: Segment number shown in the output.
    """
    # One summary line per tracker.
    for tracker in trackers:
        print(f"env summary (episode {ep}, segment {segment}): {tracker['env_metric']}")

    tracker_count = len(trackers)
    if tracker_count <= 1:
        # Nothing to average over.
        return

    # Average every metric key of the first tracker across all trackers.
    avg_metric = {}
    for key in trackers[0]["env_metric"].keys():
        avg_metric[key] = sum(tr["env_metric"][key] for tr in trackers) / tracker_count
    print(f"average env summary (episode {ep}, segment {segment}): {avg_metric}")

# post-evaluation callback
def post_evaluate(trackers, ep):
    """Print each tracker's env metric and, when there are several trackers, their average.

    Args:
        trackers: Sequence of dicts, each holding a metric dict under ``"env_metric"``.
        ep: Evaluation episode number shown in the output.
    """
    # print the env metric from each rollout worker
    for tracker in trackers:
        print(f"env summary (evaluation episode {ep}): {tracker['env_metric']}")

    # print the average env metric
    if len(trackers) > 1:
        metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
        avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
        # Use print like the rest of the callbacks; no logger object is defined in this notebook.
        print(f"average env summary (evaluation episode {ep}): {avg_metric}")


# Seed first so that the training run below is reproducible.
set_seeds(1024)

# Single-threaded training for 50 episodes, reporting through the callbacks defined above.
simple_learner(
    get_env_sampler,
    num_episodes=50,
    post_collect=post_collect,
    post_evaluate=post_evaluate,
)