# Quick Start

This notebook demonstrates how to use MARO's reinforcement learning (RL) toolkit to solve the container inventory management ([CIM](https://maro.readthedocs.io/en/latest/scenarios/container_inventory_management.html)) problem. It is formalized as a multi-agent reinforcement learning problem, where each port acts as a decision agent. The agents take actions independently, e.g., loading containers to vessels or discharging containers from vessels.   

## [State Shaper](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#shapers)

The state shaper converts an environment observation into the model's input state, which combines temporal and spatial information. For this scenario, the input state includes: 

- Temporal information, i.e., port features (such as shortage) over a short look-back window of recent ticks, together with the arriving vessel's current features (such as remaining space). 

- Spatial information, i.e., features of the related downstream ports (the vessel's future stops).    

In [None]:
import numpy as np
from maro.rl import StateShaper


# Attributes read from the "ports" snapshot list, for the decision port and
# its downstream ports, over the look-back ticks (see CIMStateShaper).
PORT_ATTRIBUTES = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
# Attributes read from the "vessels" snapshot list for the arriving vessel,
# at the decision tick only (see CIMStateShaper).
VESSEL_ATTRIBUTES = ["empty", "full", "remaining_space"]


class CIMStateShaper(StateShaper):
    """Turn a CIM decision event into an ``(agent_id, state_vector)`` pair.

    The state concatenates port attributes, for the decision port followed by
    the vessel's future stops and taken over a window of recent ticks, with the
    current attributes of the arriving vessel. The agent id is the decision
    port index rendered as a string.

    Keyword Args:
        look_back: controls the size of the tick window of port attributes.
        max_ports_downstream: number of downstream ports assumed by ``dim``.
    """

    def __init__(self, *, look_back, max_ports_downstream):
        super().__init__()
        self._look_back = look_back
        self._max_ports_downstream = max_ports_downstream
        port_slots = (look_back + 1) * (max_ports_downstream + 1)
        self._dim = port_slots * len(PORT_ATTRIBUTES) + len(VESSEL_ATTRIBUTES)

    def __call__(self, decision_event, snapshot_list):
        tick = decision_event.tick
        port_idx = decision_event.port_idx
        vessel_idx = decision_event.vessel_idx

        # NOTE(review): this window holds look_back - 1 ticks, while ``dim``
        # assumes look_back + 1. The two only agree when the length of the
        # vessel's future_stop_list compensates for the difference -- confirm
        # against the topology in use.
        history_ticks = [tick - offset for offset in range(self._look_back - 1)]
        downstream = snapshot_list["vessels"][tick: vessel_idx: 'future_stop_list'].astype('int')
        queried_ports = [port_idx, *downstream]

        port_features = snapshot_list["ports"][history_ticks: queried_ports: PORT_ATTRIBUTES]
        vessel_features = snapshot_list["vessels"][tick: vessel_idx: VESSEL_ATTRIBUTES]
        return str(port_idx), np.concatenate((port_features, vessel_features))

    @property
    def dim(self):
        """Declared length of the state vector, used as the model's input size."""
        return self._dim
    
# Shared state shaper instance; its ``dim`` also sizes the Q-network input.
state_shaper = CIMStateShaper(
    max_ports_downstream=2,
    look_back=7,
)

## [Action Shaper](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#shapers)

The action shaper converts an agent's model output into an action the environment can execute. For this scenario, the model output is a discrete index corresponding to a percentage: the fraction of containers to load onto the arriving vessel (negative values) or discharge from it (positive values).

In [None]:
from maro.rl import ActionShaper
from maro.simulator.scenarios.cim.common import Action


class CIMActionShaper(ActionShaper):
    """Map a discrete model output index to an executable CIM ``Action``.

    ``action_space`` is a list of fractions that must contain an exact 0. It is
    assumed to be sorted in ascending order, e.g. built with ``np.linspace``.
    Indexes below the zero entry load empty containers from the port onto the
    vessel. Indexes above it discharge containers from the vessel. The zero
    index yields a no-op quantity of 0.
    """

    def __init__(self, action_space):
        super().__init__()
        self._action_space = action_space
        # list.index raises ValueError when the space has no exact 0 entry.
        self._zero_action_index = action_space.index(0)

    def __call__(self, model_action, decision_event, snapshot_list):
        assert 0 <= model_action < len(self._action_space)

        scope = decision_event.action_scope
        tick = decision_event.tick
        port_idx = decision_event.port_idx
        vessel_idx = decision_event.vessel_idx
        # Several attributes are queried, but only "empty" (index 0) and
        # "remaining_space" (index 2) are used.
        port_empty = snapshot_list["ports"][tick: port_idx: ["empty", "full", "on_shipper", "on_consignee"]][0]
        vessel_remaining_space = snapshot_list["vessels"][tick: vessel_idx: ["empty", "full", "remaining_space"]][2]
        early_discharge = snapshot_list["vessels"][tick:vessel_idx: "early_discharge"][0]

        zero_index = self._zero_action_index
        if model_action < zero_index:
            # Load: scale the port's empty containers, but the (negative)
            # quantity can never exceed the vessel's remaining space.
            fraction = self._action_space[model_action]
            quantity = max(round(fraction * port_empty), -vessel_remaining_space)
            return Action(vessel_idx, port_idx, quantity)

        if model_action > zero_index:
            # Discharge: containers already discharged early count toward the
            # planned amount. If they already cover it, fall back to a plain
            # fraction of the discharge scope.
            fraction = self._action_space[model_action]
            planned = fraction * (scope.discharge + early_discharge) - early_discharge
            quantity = round(planned) if planned > 0 else round(fraction * scope.discharge)
            return Action(vessel_idx, port_idx, quantity)

        return Action(vessel_idx, port_idx, 0)
    
# Size of the discrete action grid over [-1, 1]. An odd count is needed for the
# grid to contain 0, because CIMActionShaper looks up the index of 0.
NUM_ACTIONS = 21
action_shaper = CIMActionShaper(
    action_space=list(np.linspace(-1.0, 1.0, num=NUM_ACTIONS)),
)

## [Experience Shaper](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#shapers)

The experience shaper converts an episode trajectory into trainable experiences for the RL agents. For this scenario, the reward is a linear combination of time-decayed fulfillment and shortage within a limited time window.

In [None]:
from collections import defaultdict

from maro.rl import ExperienceShaper


class TruncatedExperienceShaper(ExperienceShaper):
    """Build per-agent experiences with a truncated, time-decayed reward.

    Keyword Args:
        time_window: the reward looks at ticks ``tick + 1`` up to (but not
            including) ``tick + time_window``.
        time_decay_factor: per-tick multiplicative decay of future values.
        fulfillment_factor: weight of decayed fulfillment (added).
        shortage_factor: weight of decayed shortage (subtracted).
    """

    def __init__(
        self, *, time_window: int, time_decay_factor: float, fulfillment_factor: float, shortage_factor: float
    ):
        super().__init__(reward_func=None)
        self._time_window = time_window
        self._time_decay_factor = time_decay_factor
        self._fulfillment_factor = fulfillment_factor
        self._shortage_factor = shortage_factor

    def __call__(self, trajectory, snapshot_list):
        """Group transitions by agent id into state/action/reward/next_state lists.

        Every transition except the last yields one experience. Its
        ``next_state`` is the state of the following transition in the
        trajectory, whichever agent that transition belongs to.
        """
        experiences_by_agent = {}
        for pos in range(len(trajectory) - 1):
            current = trajectory[pos]
            bucket = experiences_by_agent.setdefault(current["agent_id"], defaultdict(list))
            bucket["state"].append(current["state"])
            bucket["action"].append(current["action"])
            bucket["reward"].append(self._compute_reward(current["event"], snapshot_list))
            bucket["next_state"].append(trajectory[pos + 1]["state"])

        return experiences_by_agent

    def _compute_reward(self, decision_event, snapshot_list):
        """Return the weighted, decayed fulfillment minus shortage over all ports."""
        first_tick = decision_event.tick + 1
        window = list(range(first_tick, decision_event.tick + self._time_window))
        num_ticks = len(window)

        # calculate tc reward
        future_fulfillment = snapshot_list["ports"][window::"fulfillment"]
        future_shortage = snapshot_list["ports"][window::"shortage"]

        # The query results are assumed to be laid out tick-major (all ports for
        # tick 0, then tick 1, ...). Each tick's decay weight is therefore
        # repeated once per value belonging to that tick.
        decay_weights = []
        if num_ticks:
            values_per_tick = future_fulfillment.shape[0] // num_ticks
            for step in range(num_ticks):
                decay_weights.extend([self._time_decay_factor ** step] * values_per_tick)

        tot_fulfillment = np.dot(future_fulfillment, decay_weights)
        tot_shortage = np.dot(future_shortage, decay_weights)

        return np.float32(self._fulfillment_factor * tot_fulfillment - self._shortage_factor * tot_shortage)
    
# Unit weights make the reward (decayed fulfillment) - (decayed shortage),
# with values decayed by 0.97 per tick.
experience_shaper = TruncatedExperienceShaper(
    time_window=100,
    time_decay_factor=0.97,
    fulfillment_factor=1.0,
    shortage_factor=1.0,
)

## [Agent](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#agent)

For this scenario, each agent represents a port. We use DQN as the underlying learning algorithm, with a TD-error-based sampling mechanism.  

In [None]:
from maro.rl import AbsAgent, ColumnBasedStore


class CIMAgent(AbsAgent):
    """Port agent that trains its algorithm on samples keyed by stored loss.

    Training is a no-op until the experience pool holds at least
    ``min_experiences_to_train`` entries. After that, each ``train`` call runs
    ``num_batches`` updates of ``batch_size`` samples and writes the loss
    returned by the algorithm back into the pool's ``loss`` column.
    """

    def __init__(self, name, algorithm, experience_pool: ColumnBasedStore, min_experiences_to_train, num_batches, batch_size):
        super().__init__(name, algorithm, experience_pool)
        self._min_experiences_to_train = min_experiences_to_train
        self._num_batches = num_batches
        self._batch_size = batch_size

    def train(self):
        pool = self._experience_pool
        if len(pool) < self._min_experiences_to_train:
            return

        for _ in range(self._num_batches):
            indexes, sample = pool.sample_by_key("loss", self._batch_size)
            state, action, reward, next_state = (
                np.asarray(sample[column]) for column in ("state", "action", "reward", "next_state")
            )
            loss = self._algorithm.train(state, action, reward, next_state)
            pool.update(indexes, {"loss": loss})

## [Agent Manager](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#agent-manager)

An AgentManager isolates the complexities of the environment from the learning algorithm by managing the individual agents. Below, we define a function that creates the agents and an agent manager class whose ``train`` method stores newly obtained experiences in the agents' experience pools before training them, as the DQN algorithm requires.

In [None]:
import io
import yaml

import torch.nn as nn
from torch.nn.functional import smooth_l1_loss
from torch.optim import RMSprop

from maro.rl import (
    ColumnBasedStore, DQN, DQNConfig, FullyConnectedBlock, LearningModuleManager, LearningModule, OptimizerOptions, SimpleAgentManager
)
from maro.utils import set_seeds


def create_dqn_agents(agent_id_list):
    """Create one DQN-backed ``CIMAgent`` per id and return them as ``{id: agent}``.

    Seeds are fixed once before any network is built, so the agents are
    constructed in ``agent_id_list`` order from a reproducible random state.
    """
    set_seeds(1)  # for reproducibility

    def build_agent(agent_id):
        # Q-network: state vector in, one Q-value per discrete action out.
        q_network = FullyConnectedBlock(
            input_dim=state_shaper.dim,
            hidden_dims=[256, 128, 64],
            output_dim=NUM_ACTIONS,
            activation=nn.LeakyReLU,
            is_head=True,
            batch_norm_enabled=True,
            softmax_enabled=False,
            skip_connection_enabled=False,
            dropout_p=.0,
        )
        q_module = LearningModule(
            "q_value",
            [q_network],
            optimizer_options=OptimizerOptions(cls=RMSprop, params={"lr": 0.05}),
        )
        # reward_decay=0: the truncated experience shaper already folds future
        # ticks into each reward.
        dqn = DQN(
            model=LearningModuleManager(q_module),
            config=DQNConfig(
                reward_decay=.0,
                target_update_frequency=5,
                tau=0.1,
                is_double=True,
                per_sample_td_error_enabled=True,
                loss_cls=nn.SmoothL1Loss,
                num_actions=NUM_ACTIONS,
            ),
        )
        return CIMAgent(
            agent_id, dqn, ColumnBasedStore(), min_experiences_to_train=1024, num_batches=10, batch_size=128
        )

    return {agent_id: build_agent(agent_id) for agent_id in agent_id_list}


class DQNAgentManager(SimpleAgentManager):
    """Agent manager that stores fresh experiences per agent, then trains all agents."""

    def train(self, experiences_by_agent, performance=None):
        """Add a ``loss`` column to each agent's new experiences, store them, then train.

        New experiences get an initial loss of 1e8, presumably so that
        loss-keyed sampling favors them until they are first trained on.
        ``performance`` is accepted for interface compatibility and is unused.
        Note that each experience dict is modified in place.
        """
        self._assert_train_mode()

        for agent_id, exp in experiences_by_agent.items():
            num_new = len([*exp.values()][0])
            exp["loss"] = [1e8] * num_new
            self.agent_dict[agent_id].store_experiences(exp)

        for agent in self.agent_dict.values():
            agent.train()

## Main Loop with [Actor and Learner](https://maro.readthedocs.io/en/latest/key_components/rl_toolkit.html#learner-and-actor)

This code cell demonstrates the typical workflow of a learning policy's interaction with a MARO environment. 

- Initialize an environment with specific scenario and topology parameters. 

- Define scenario-specific components, e.g. shapers. 

- Create agents and an agent manager. 

- Create an actor and a learner to start the training process in which the agent manager interacts with the environment for collecting experiences and updating policies. 

In [None]:
from maro.simulator import Env
from maro.rl import AgentManagerMode, SimpleActor, SimpleLearner, TwoPhaseLinearParameterScheduler
from maro.utils import LogFormat, Logger

# Step 1: build a CIM environment on the toy topology, running for 1120 ticks.
env = Env("cim", "toy.4p_ssdd_l0.0", durations=1120)
agent_id_list = list(map(str, env.agent_idx_list))

# Step 2: one DQN agent per port, wrapped by an agent manager that owns the shapers.
agent_manager = DQNAgentManager(
    name="cim_learner",
    mode=AgentManagerMode.TRAIN_INFERENCE,
    agent_dict=create_dqn_agents(agent_id_list),
    state_shaper=state_shaper,
    action_shaper=action_shaper,
    experience_shaper=experience_shaper,
)

# Step 3: epsilon goes 0.4 -> 0.32 over the first 50 episodes, then -> 0.0 by
# the last episode. The actor rolls out episodes and the learner drives training.
max_episode = 100
scheduler = TwoPhaseLinearParameterScheduler(
    max_episode,
    parameter_names=["epsilon"],
    split_ep=50,
    start_values=0.4,
    mid_values=0.32,
    end_values=.0,
)

actor = SimpleActor(env, agent_manager)
learner = SimpleLearner(
    agent_manager,
    actor,
    scheduler,
    logger=Logger("single_host_cim_learner", format_=LogFormat.simple, auto_timestamp=False),
)

learner.learn()
learner.test()