# Quick Start

This notebook demonstrates how to use MARO's reinforcement learning (RL) toolkit to solve the container inventory management ([CIM](https://maro.readthedocs.io/en/latest/scenarios/container_inventory_management.html)) problem. It is formalized as a multi-agent reinforcement learning problem, where each port acts as a decision agent. When a vessel arrives at a port, these agents must take actions by transferring a certain number of containers to / from the vessel. The objective is for the agents to learn policies that minimize the cumulative container shortage.

In [1]:
import numpy as np

# Settings shared across the notebook: they are handed to the trajectory class as
# keyword arguments and are also read when sizing the model input and output.
common_config = dict(
    # Snapshot attributes that make up the state vector
    port_attributes=["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"],
    vessel_attributes=["empty", "full", "remaining_space"],
    # 21 evenly spaced values in [-1, 1]; the model outputs an index into this list
    action_space=list(np.linspace(-1.0, 1.0, 21)),
    # State computation
    look_back=7,
    max_ports_downstream=2,
    # Action computation
    finite_vessel_space=True,
    has_early_discharge=True,
    # Reward computation
    reward_time_window=99,
    fulfillment_factor=1.0,
    shortage_factor=1.0,
    time_decay=0.97,
)

## Shaping

In [2]:
from collections import defaultdict
import numpy as np
from maro.rl import Trajectory
from maro.simulator.scenarios.cim.common import Action, ActionType


class CIMTrajectory(Trajectory):
    """Trajectory shaper for the CIM scenario.

    Turns decision events into per-port model states, turns model outputs into
    environment ``Action`` objects, and, once a roll-out is finished, assembles
    per-agent training data whose rewards are computed from port snapshots.

    Args:
        env: Environment to collect the trajectory from (forwarded to ``Trajectory``).
        port_attributes: Port snapshot attribute names used as state features.
        vessel_attributes: Vessel snapshot attribute names used as state features.
            ``get_action`` reads the entry at index 2 of the queried vessel
            features as the space available on the vessel.
        action_space: Sequence of signed fractions indexed by the model output.
        look_back: Controls how many ticks of port history go into the state.
        max_ports_downstream: Stored on the instance; not read by any method here.
        reward_time_window: Number of ticks after an event that contribute to its reward.
        fulfillment_factor: Weight of the decayed fulfillment in the reward.
        shortage_factor: Weight of the decayed shortage in the reward.
        time_decay: Per-tick decay base applied inside the reward window.
        finite_vessel_space: If True, loads are capped by the vessel's queried space.
        has_early_discharge: If True, the vessel's "early_discharge" value is
            taken into account when computing discharge quantities.
    """

    def __init__(
        self, env, *, port_attributes, vessel_attributes, action_space, look_back, max_ports_downstream,
        reward_time_window, fulfillment_factor, shortage_factor, time_decay,
        finite_vessel_space=True, has_early_discharge=True
    ):
        super().__init__(env)
        self.port_attributes = port_attributes
        self.vessel_attributes = vessel_attributes
        self.action_space = action_space
        self.look_back = look_back
        self.max_ports_downstream = max_ports_downstream
        self.reward_time_window = reward_time_window
        self.fulfillment_factor = fulfillment_factor
        self.shortage_factor = shortage_factor
        self.time_decay = time_decay
        self.finite_vessel_space = finite_vessel_space
        self.has_early_discharge = has_early_discharge

    def get_state(self, event):
        """Build the model state for the port that triggered ``event``.

        Returns:
            dict: ``{port_idx: features}`` where ``features`` is the concatenation
            of the queried port features and the current vessel features.
        """
        vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
        tick, port_idx, vessel_idx = event.tick, event.port_idx, event.vessel_idx
        # Current tick plus earlier ticks, clamped at 0 so early events do not query negative ticks.
        # NOTE(review): this yields ``look_back - 1`` ticks; confirm the resulting state length
        # matches the input dimension the model is built with.
        ticks = [max(0, tick - rt) for rt in range(self.look_back - 1)]
        future_port_idx_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
        # Ports queried: the event's port first, then the vessel's future stops.
        port_features = port_snapshots[ticks: [port_idx] + list(future_port_idx_list): self.port_attributes]
        vessel_features = vessel_snapshots[tick: vessel_idx: self.vessel_attributes]
        return {port_idx: np.concatenate((port_features, vessel_features))}

    def get_action(self, action_by_agent, event):
        """Translate the model output into an environment action.

        Args:
            action_by_agent (dict): Single-entry mapping whose value is either the
                action index or a tuple with the action index as its first element.
            event: Decision event providing ``action_scope``, ``tick``, ``port_idx``
                and ``vessel_idx``.

        Returns:
            dict: ``{port_idx: Action}``. Indices below the middle of the action
            space load, indices above it discharge, and the middle index yields a
            zero-quantity action with ``action_type`` set to ``None``.
        """
        vessel_snapshots = self.env.snapshot_list["vessels"]
        action_info = list(action_by_agent.values())[0]
        model_action = action_info[0] if isinstance(action_info, tuple) else action_info
        scope, tick, port, vessel = event.action_scope, event.tick, event.port_idx, event.vessel_idx
        # Index corresponding to value zero. Floor division is required: true division gives a
        # non-integer for an odd-sized action space (10.5 for 21 actions), which no integer index
        # can equal, so the zero action was handled as a LOAD of 0 and the no-op branch was dead.
        zero_action_idx = len(self.action_space) // 2
        # Index 2 of the queried vessel attributes is "remaining_space" in this notebook's config.
        vessel_space = vessel_snapshots[tick:vessel:self.vessel_attributes][2] if self.finite_vessel_space else float("inf")
        early_discharge = vessel_snapshots[tick:vessel:"early_discharge"][0] if self.has_early_discharge else 0
        percent = abs(self.action_space[model_action])

        if model_action < zero_action_idx:
            action_type = ActionType.LOAD
            # Never load more than the vessel can take.
            actual_action = min(round(percent * scope.load), vessel_space)
        elif model_action > zero_action_idx:
            action_type = ActionType.DISCHARGE
            # Subtract the early-discharged amount from the planned quantity; fall back to the
            # plain percentage of the discharge scope when that plan is not positive.
            plan_action = percent * (scope.discharge + early_discharge) - early_discharge
            actual_action = round(plan_action) if plan_action > 0 else round(percent * scope.discharge)
        else:
            actual_action, action_type = 0, None

        return {port: Action(vessel, port, actual_action, action_type)}

    def get_offline_reward(self, event):
        """Compute the reward for ``event`` from snapshots of the ticks that follow it.

        The reward is ``fulfillment_factor * F - shortage_factor * S`` where ``F`` and
        ``S`` are time-decayed sums of the "fulfillment" and "shortage" values over
        the ``reward_time_window`` ticks starting right after the event.

        Returns:
            numpy.float32: The scalar reward.
        """
        port_snapshots = self.env.snapshot_list["ports"]
        start_tick = event.tick + 1
        ticks = list(range(start_tick, start_tick + self.reward_time_window))

        # The node index is omitted in both queries (presumably selecting every port — confirm).
        future_fulfillment = port_snapshots[ticks::"fulfillment"]
        future_shortage = port_snapshots[ticks::"shortage"]
        # Repeat each tick's decay weight once per value returned for that tick, so the weights
        # line up with the flat query results.
        decay_list = [
            self.time_decay ** i for i in range(self.reward_time_window)
            for _ in range(future_fulfillment.shape[0] // self.reward_time_window)
        ]

        tot_fulfillment = np.dot(future_fulfillment, decay_list)
        tot_shortage = np.dot(future_shortage, decay_list)

        return np.float32(self.fulfillment_factor * tot_fulfillment - self.shortage_factor * tot_shortage)

    def on_env_feedback(self, event, state_by_agent, action_by_agent, reward):
        """Record one transition; ``reward`` is ignored because rewards are computed in ``on_finish``."""
        self.trajectory["event"].append(event)
        self.trajectory["state"].append(state_by_agent)
        self.trajectory["action"].append(action_by_agent)

    def on_finish(self):
        """Group the recorded transitions by agent and attach offline rewards.

        Returns:
            dict: ``{agent_id: {"args": [states, actions, log_ps, rewards]}}`` with each
            list converted to a numpy array (rewards as ``float32``).
        """
        training_data = {}
        for event, state, action in zip(self.trajectory["event"], self.trajectory["state"], self.trajectory["action"]):
            agent_id = list(state.keys())[0]
            data = training_data.setdefault(agent_id, {"args": [[] for _ in range(4)]})
            data["args"][0].append(state[agent_id])  # state
            data["args"][1].append(action[agent_id][0])  # action
            data["args"][2].append(action[agent_id][1])  # log_p
            data["args"][3].append(self.get_offline_reward(event))  # reward

        for agent_id in training_data:
            training_data[agent_id]["args"] = [
                np.asarray(vals, dtype=np.float32 if i == 3 else None)
                for i, vals in enumerate(training_data[agent_id]["args"])
            ]

        return training_data

## [Agent](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#agent)

The out-of-the-box ActorCritic is used as our agent.

In [3]:
import torch.nn as nn
from torch.optim import Adam, RMSprop

from maro.rl import ActorCritic, ActorCriticConfig, FullyConnectedBlock, OptimOption, SimpleMultiHeadModel

# Length of the flat state vector fed to the actor and critic networks: port features for
# (look_back + 1) time steps and (max_ports_downstream + 1) ports (the port in question
# plus its downstream ports), followed by one set of vessel features.
# NOTE(review): CIMTrajectory.get_state queries `look_back - 1` ticks and every port in the
# vessel's future_stop_list; confirm its output length equals this value for the topology used.
input_dim = len(common_config["vessel_attributes"]) + (
    len(common_config["port_attributes"])
    * (common_config["max_ports_downstream"] + 1)
    * (common_config["look_back"] + 1)
)

# Everything needed to build one ActorCritic agent: network shapes, optimizers and
# algorithm hyper-parameters.
agent_config = dict(
    model=dict(
        # Policy network: one output per entry of the action space, softmax-normalized.
        actor=dict(
            input_dim=input_dim,
            output_dim=len(common_config["action_space"]),
            hidden_dims=[256, 128, 64],
            activation=nn.Tanh,
            softmax=True,
            batch_norm=False,
            head=True,
        ),
        # Value network: a single scalar output.
        critic=dict(
            input_dim=input_dim,
            output_dim=1,
            hidden_dims=[256, 128, 64],
            activation=nn.LeakyReLU,
            softmax=False,
            batch_norm=True,
            head=True,
        ),
    ),
    optimization=dict(
        actor=OptimOption(optim_cls=Adam, optim_params={"lr": 0.001}),
        critic=OptimOption(optim_cls=RMSprop, optim_params={"lr": 0.001}),
    ),
    hyper_params=dict(
        reward_discount=.0,
        critic_loss_func=nn.SmoothL1Loss(),
        train_iters=10,
        actor_loss_coefficient=0.1,  # loss = actor_loss_coefficient * actor_loss + critic_loss
        k=1,  # for k-step return
        lam=0.0,  # lambda return coefficient
    ),
)

def get_ac_agent():
    """Create a new ActorCritic agent from the settings in ``agent_config``.

    Each call builds fresh actor and critic networks, so agents returned by
    separate calls do not share model objects.

    Returns:
        ActorCritic: Agent wrapping a two-headed ("actor", "critic") model.
    """
    model_settings = agent_config["model"]
    # Actor is built before critic, matching the key order of the resulting mapping.
    networks = {
        head: FullyConnectedBlock(**model_settings[head]) for head in ("actor", "critic")
    }
    multi_head_model = SimpleMultiHeadModel(networks, optim_option=agent_config["optimization"])
    algorithm_config = ActorCriticConfig(**agent_config["hyper_params"])
    return ActorCritic(multi_head_model, algorithm_config)

## Training

This code cell demonstrates a typical single-threaded training workflow.

In [4]:
from maro.simulator import Env
from maro.rl import Actor, MultiAgentWrapper, OnPolicyLearner
from maro.utils import set_seeds

# Seed first, before the environment and the agents are created, so the run is reproducible.
set_seeds(1024)  # for reproducibility
# CIM scenario on the "toy.4p_ssdd_l0.0" topology (presumably 1120 simulated ticks per episode — confirm).
env = Env("cim", "toy.4p_ssdd_l0.0", durations=1120)
# One independently built ActorCritic agent per entry in env.agent_idx_list.
agent = MultiAgentWrapper({name: get_ac_agent() for name in env.agent_idx_list})
# The actor receives the trajectory class plus common_config as its keyword arguments.
actor = Actor(env, agent, CIMTrajectory, trajectory_kwargs=common_config)
learner = OnPolicyLearner(actor, 40)  # 40 episodes
# Starts training; the learner logs per-episode metrics (see the output below).
learner.run()

14:54:17 | LEARNER | INFO | ep-0: {'order_requirements': 2240000, 'container_shortage': 1422736, 'operation_number': 4220466}
14:54:19 | LEARNER | INFO | Agent learning finished
14:54:23 | LEARNER | INFO | ep-1: {'order_requirements': 2240000, 'container_shortage': 1330641, 'operation_number': 3919970}
14:54:24 | LEARNER | INFO | Agent learning finished
14:54:29 | LEARNER | INFO | ep-2: {'order_requirements': 2240000, 'container_shortage': 996878, 'operation_number': 3226186}
14:54:30 | LEARNER | INFO | Agent learning finished
14:54:34 | LEARNER | INFO | ep-3: {'order_requirements': 2240000, 'container_shortage': 703662, 'operation_number': 3608511}
14:54:36 | LEARNER | INFO | Agent learning finished
14:54:40 | LEARNER | INFO | ep-4: {'order_requirements': 2240000, 'container_shortage': 601934, 'operation_number': 3579281}
14:54:41 | LEARNER | INFO | Agent learning finished
14:54:45 | LEARNER | INFO | ep-5: {'order_requirements': 2240000, 'container_shortage': 629344, 'operation_number