diff --git a/examples/cim/dqn/components/experience_shaper.py b/examples/cim/dqn/components/experience_shaper.py index 2941f2159..4d82b5233 100644 --- a/examples/cim/dqn/components/experience_shaper.py +++ b/examples/cim/dqn/components/experience_shaper.py @@ -41,7 +41,7 @@ def _compute_reward(self, decision_event, snapshot_list): future_fulfillment = snapshot_list["ports"][ticks::"fulfillment"] future_shortage = snapshot_list["ports"][ticks::"shortage"] decay_list = [self._time_decay_factor ** i for i in range(end_tick - start_tick) - for _ in range(future_fulfillment.shape[0]//(end_tick-start_tick))] + for _ in range(future_fulfillment.shape[0] // (end_tick - start_tick))] tot_fulfillment = np.dot(future_fulfillment, decay_list) tot_shortage = np.dot(future_shortage, decay_list) diff --git a/examples/cim/dqn/single_process_launcher.py b/examples/cim/dqn/single_process_launcher.py index a906513ea..5e80df041 100644 --- a/examples/cim/dqn/single_process_launcher.py +++ b/examples/cim/dqn/single_process_launcher.py @@ -29,7 +29,8 @@ if config.experience_shaping.type == "truncated": experience_shaper = TruncatedExperienceShaper(**config.experience_shaping.truncated) else: - experience_shaper = KStepExperienceShaper(reward_func=lambda mt: 1-mt["container_shortage"]/mt["order_requirements"], + experience_shaper = KStepExperienceShaper(reward_func=lambda mt: 1 - mt["container_shortage"] / + mt["order_requirements"], **config.experience_shaping.k_step) exploration_config = {"epsilon_range_dict": {"_all_": config.exploration.epsilon_range}, diff --git a/examples/cim/gnn/action_shaper.py b/examples/cim/gnn/action_shaper.py new file mode 100644 index 000000000..9daf20eef --- /dev/null +++ b/examples/cim/gnn/action_shaper.py @@ -0,0 +1,37 @@ +from maro.rl import ActionShaper + + +class DiscreteActionShaper(ActionShaper): + """The shaping class to transform the action in [-1, 1] to actual repositioning function.""" + def __init__(self, action_dim): + super().__init__() + self._action_dim = action_dim + self._zero_action = self._action_dim // 2 + + def __call__(self, decision_event, model_action): + """Shaping the action in [-1,1] range to the actual repositioning function. + + This function maps integer model action within the range of [-A, A] to actual action. We define negative actual + action as discharge resource from vessel to port and positive action as upload from port to vessel, so the + upper bound and lower bound of actual action are the resource in dynamic and static node respectively. + + Args: + decision_event (Event): The decision event from the environment. + model_action (int): Output action, range A means the half of the agent output dim. + """ + env_action = 0 + model_action -= self._zero_action + + action_scope = decision_event.action_scope + + if model_action < 0: + # Discharge resource from dynamic node. + env_action = round(int(model_action) * 1.0 / self._zero_action * action_scope.load) + elif model_action == 0: + env_action = 0 + else: + # Load resource to dynamic node. + env_action = round(int(model_action) * 1.0 / self._zero_action * action_scope.discharge) + env_action = int(env_action) + + return env_action diff --git a/examples/cim/gnn/actor.py b/examples/cim/gnn/actor.py new file mode 100644 index 000000000..4c9215d8c --- /dev/null +++ b/examples/cim/gnn/actor.py @@ -0,0 +1,362 @@ +import ctypes +import multiprocessing +import os +import pickle +import time +from collections import OrderedDict +from multiprocessing import Process, Pipe + +import numpy as np +import torch + +from examples.cim.gnn.action_shaper import DiscreteActionShaper +from examples.cim.gnn.experience_shaper import ExperienceShaper +from examples.cim.gnn.state_shaper import GNNStateShaper +from examples.cim.gnn.shared_structure import SharedStructure +from examples.cim.gnn.utils import fix_seed, gnn_union +from maro.rl import AbsActor +from maro.simulator import Env +from maro.simulator.scenarios.cim.common import Action + + +def organize_exp_list(experience_collections: dict, idx_mapping: dict): + """The function assemble the experience from multiple processes into a dictionary. + + Args: + experience_collections (dict): It stores the experience in all agents. The structure is the same as what is + defined in the SharedStructure in the ParallelActor except additional key for experience length. For + example: + + { + "len": numpy.array, + "s": { + "v": numpy.array, + "p": numpy.array, + } + "a": numpy.array, + "R": numpy.array, + "s_": { + "v": numpy.array, + "p": numpy.array, + } + } + + Note that the experience from different agents are stored in the same batch in a sequential way. For + example, if agent x starts at b_x in batch index and the experience is l_x length long, the range [b_x, + l_x) in the batch is the experience of agent x. + + idx_mapping (dict): The key is the name of each agent and the value is the starting index, e.g., b_x, of the + storage space where the experience of the agent is stored. + """ + result = {} + tmpi = 0 + for code, idx in idx_mapping.items(): + exp_len = experience_collections["len"][0][tmpi] + + s = organize_obs(experience_collections["s"], idx, exp_len) + s_ = organize_obs(experience_collections["s_"], idx, exp_len) + R = experience_collections["R"][idx: idx + exp_len] + R = R.reshape(-1, *R.shape[2:]) + a = experience_collections["a"][idx: idx + exp_len] + a = a.reshape(-1, *a.shape[2:]) + + result[code] = { + "R": R, + "a": a, + "s": s, + "s_": s_, + "len": a.shape[0] + } + tmpi += 1 + return result + + +def organize_obs(obs, idx, exp_len): + """Helper function to transform the observation from multiple processes to a unified dictionary.""" + tick_buffer, _, para_cnt, v_cnt, v_dim = obs["v"].shape + _, _, _, p_cnt, p_dim = obs["p"].shape + batch = exp_len * para_cnt + # v: tick_buffer, seq_len, parallel_cnt, v_cnt, v_dim --> (tick_buffer, cnt, v_cnt, v_dim) + v = obs["v"][:, idx: idx + exp_len] + v = v.reshape(tick_buffer, batch, v_cnt, v_dim) + p = obs["p"][:, idx: idx + exp_len] + p = p.reshape(tick_buffer, batch, p_cnt, p_dim) + # vo: seq_len * parallel_cnt * v_cnt * p_cnt* --> cnt * v_cnt * p_cnt* + vo = obs["vo"][idx: idx + exp_len] + vo = vo.reshape(batch, v_cnt, vo.shape[-1]) + po = obs["po"][idx: idx + exp_len] + po = po.reshape(batch, p_cnt, po.shape[-1]) + vedge = obs["vedge"][idx: idx + exp_len] + vedge = vedge.reshape(batch, v_cnt, vedge.shape[-2], vedge.shape[-1]) + pedge = obs["pedge"][idx: idx + exp_len] + pedge = pedge.reshape(batch, p_cnt, pedge.shape[-2], pedge.shape[-1]) + ppedge = obs["ppedge"][idx: idx + exp_len] + ppedge = ppedge.reshape(batch, p_cnt, ppedge.shape[-2], ppedge.shape[-1]) + + # mask: (seq_len, parallel_cnt, tick_buffer) + mask = obs["mask"][idx: idx + exp_len].reshape(batch, tick_buffer) + + return {"v": v, "p": p, "vo": vo, "po": po, "pedge": pedge, "vedge": vedge, "ppedge": ppedge, "mask": mask} + + +def single_player_worker(index, config, exp_idx_mapping, pipe, action_io, exp_output): + """The A2C worker function to collect experience. + + Args: + index (int): The process index counted from 0. + config (dict): It is a dottable dictionary that stores the configuration of the simulation, state_shaper and + postprocessing shaper. + exp_idx_mapping (dict): The key is agent code and the value is the starting index where the experience is stored + in the experience batch. + pipe (Pipe): The pipe instance for communication with the main process. + action_io (SharedStructure): The shared memory to hold the state information that the main process uses to + generate an action. + exp_output (SharedStructure): The shared memory to transfer the experience list to the main process. + """ + env = Env(**config.env.param) + fix_seed(env, config.env.seed) + static_code_list, dynamic_code_list = list(env.summary["node_mapping"]["ports"].values()), \ + list(env.summary["node_mapping"]["vessels"].values()) + # Create gnn_state_shaper without consuming any resources. + + gnn_state_shaper = GNNStateShaper( + static_code_list, dynamic_code_list, config.env.param.durations, config.model.feature, + tick_buffer=config.model.tick_buffer, max_value=env.configs["total_containers"]) + gnn_state_shaper.compute_static_graph_structure(env) + + action_io_np = action_io.structuralize() + + action_shaper = DiscreteActionShaper(config.model.action_dim) + exp_shaper = ExperienceShaper( + static_code_list, dynamic_code_list, config.env.param.durations, gnn_state_shaper, + scale_factor=config.env.return_scaler, time_slot=config.training.td_steps, + discount_factor=config.training.gamma, idx=index, shared_storage=exp_output.structuralize(), + exp_idx_mapping=exp_idx_mapping) + + i = 0 + while pipe.recv() == "reset": + env.reset() + r, decision_event, is_done = env.step(None) + + j = 0 + logs = [] + while not is_done: + model_input = gnn_state_shaper(decision_event, env.snapshot_list) + action_io_np["v"][:, index] = model_input["v"] + action_io_np["p"][:, index] = model_input["p"] + action_io_np["vo"][index] = model_input["vo"] + action_io_np["po"][index] = model_input["po"] + action_io_np["vedge"][index] = model_input["vedge"] + action_io_np["pedge"][index] = model_input["pedge"] + action_io_np["ppedge"][index] = model_input["ppedge"] + action_io_np["mask"][index] = model_input["mask"] + action_io_np["pid"][index] = decision_event.port_idx + action_io_np["vid"][index] = decision_event.vessel_idx + pipe.send("features") + model_action = pipe.recv() + env_action = action_shaper(decision_event, model_action) + exp_shaper.record(decision_event=decision_event, model_action=model_action, model_input=model_input) + logs.append([ + index, decision_event.tick, decision_event.port_idx, decision_event.vessel_idx, model_action, + env_action, decision_event.action_scope.load, decision_event.action_scope.discharge]) + action = Action(decision_event.vessel_idx, decision_event.port_idx, env_action) + r, decision_event, is_done = env.step(action) + j += 1 + action_io_np["sh"][index] = compute_shortage(env.snapshot_list, config.env.param.durations, static_code_list) + i += 1 + pipe.send("done") + gnn_state_shaper.end_ep_callback(env.snapshot_list) + # Organize and synchronize exp to shared memory. + exp_shaper(env.snapshot_list) + exp_shaper.reset() + logs = np.array(logs, dtype=np.float) + pipe.send(logs) + + +def compute_shortage(snapshot_list, max_tick, static_code_list): + """Helper function to compute the shortage after a episode end.""" + return np.sum(snapshot_list["ports"][max_tick - 1: static_code_list: "acc_shortage"]) + + +class ParallelActor(AbsActor): + def __init__(self, config, demo_env, gnn_state_shaper, agent_manager, logger): + """A2C rollout class. + + This implements the synchronized A2C structure. Multiple processes are created to simulate and collect + experience where only CPU is needed and whenever an action is required, they notify the main process and the + main process will do the batch action inference with GPU. + + Args: + config (dict): The configuration to run the simulation. + demo_env (maro.simulator.Env): To get configuration information such as the amount of vessels and ports as + well as the topology of the environment, the example environment is needed. + gnn_state_shaper (AbsShaper): The state shaper instance to extract graph information from the state of + the environment. + agent_manager (AbsAgentManger): The agent manager instance to do the action inference in batch. + logger: The logger instance to log information during the rollout. + + """ + super().__init__(demo_env, agent_manager) + multiprocessing.set_start_method("spawn", True) + self._logger = logger + self.config = config + + self._static_node_mapping = demo_env.summary["node_mapping"]["ports"] + self._dynamic_node_mapping = demo_env.summary["node_mapping"]["vessels"] + self._gnn_state_shaper = gnn_state_shaper + self.device = torch.device(config.training.device) + + self.parallel_cnt = config.training.parallel_cnt + self.log_header = [f"sh_{i}" for i in range(self.parallel_cnt)] + + tick_buffer = config.model.tick_buffer + + v_dim, vedge_dim, v_cnt = self._gnn_state_shaper.get_input_dim("v"), \ + self._gnn_state_shaper.get_input_dim("vedge"), len(self._dynamic_node_mapping) + p_dim, pedge_dim, p_cnt = self._gnn_state_shaper.get_input_dim("p"), \ + self._gnn_state_shaper.get_input_dim("pedge"), len(self._static_node_mapping) + + self.pipes = [Pipe() for i in range(self.parallel_cnt)] + + action_io_structure = { + "p": ((tick_buffer, self.parallel_cnt, p_cnt, p_dim), ctypes.c_float), + "v": ((tick_buffer, self.parallel_cnt, v_cnt, v_dim), ctypes.c_float), + "po": ((self.parallel_cnt, p_cnt, v_cnt), ctypes.c_long), + "vo": ((self.parallel_cnt, v_cnt, p_cnt), ctypes.c_long), + "vedge": ((self.parallel_cnt, v_cnt, p_cnt, vedge_dim), ctypes.c_float), + "pedge": ((self.parallel_cnt, p_cnt, v_cnt, vedge_dim), ctypes.c_float), + "ppedge": ((self.parallel_cnt, p_cnt, p_cnt, pedge_dim), ctypes.c_float), + "mask": ((self.parallel_cnt, tick_buffer), ctypes.c_bool), + "sh": ((self.parallel_cnt, ), ctypes.c_long), + "pid": ((self.parallel_cnt, ), ctypes.c_long), + "vid": ((self.parallel_cnt, ), ctypes.c_long) + } + self.action_io = SharedStructure(action_io_structure) + self.action_io_np = self.action_io.structuralize() + + tot_exp_len = sum(config.env.exp_per_ep.values()) + + exp_output_structure = { + "s": { + "v": ((tick_buffer, tot_exp_len, self.parallel_cnt, v_cnt, v_dim), ctypes.c_float), + "p": ((tick_buffer, tot_exp_len, self.parallel_cnt, p_cnt, p_dim), ctypes.c_float), + "vo": ((tot_exp_len, self.parallel_cnt, v_cnt, p_cnt), ctypes.c_long), + "po": ((tot_exp_len, self.parallel_cnt, p_cnt, v_cnt), ctypes.c_long), + "vedge": ((tot_exp_len, self.parallel_cnt, v_cnt, p_cnt, vedge_dim), ctypes.c_float), + "pedge": ((tot_exp_len, self.parallel_cnt, p_cnt, v_cnt, vedge_dim), ctypes.c_float), + "ppedge": ((tot_exp_len, self.parallel_cnt, p_cnt, p_cnt, pedge_dim), ctypes.c_float), + "mask": ((tot_exp_len, self.parallel_cnt, tick_buffer), ctypes.c_bool) + }, + "s_": { + "v": ((tick_buffer, tot_exp_len, self.parallel_cnt, v_cnt, v_dim), ctypes.c_float), + "p": ((tick_buffer, tot_exp_len, self.parallel_cnt, p_cnt, p_dim), ctypes.c_float), + "vo": ((tot_exp_len, self.parallel_cnt, v_cnt, p_cnt), ctypes.c_long), + "po": ((tot_exp_len, self.parallel_cnt, p_cnt, v_cnt), ctypes.c_long), + "vedge": ((tot_exp_len, self.parallel_cnt, v_cnt, p_cnt, vedge_dim), ctypes.c_float), + "pedge": ((tot_exp_len, self.parallel_cnt, p_cnt, v_cnt, vedge_dim), ctypes.c_float), + "ppedge": ((tot_exp_len, self.parallel_cnt, p_cnt, p_cnt, pedge_dim), ctypes.c_float), + "mask": ((tot_exp_len, self.parallel_cnt, tick_buffer), ctypes.c_bool) + }, + "a": ((tot_exp_len, self.parallel_cnt), ctypes.c_long), + "len": ((self.parallel_cnt, len(config.env.exp_per_ep)), ctypes.c_long), + "R": ((tot_exp_len, self.parallel_cnt, p_cnt), ctypes.c_float), + } + self.exp_output = SharedStructure(exp_output_structure) + self.exp_output_np = self.exp_output.structuralize() + + self._logger.info("allocate complete") + + self.exp_idx_mapping = OrderedDict() + acc_c = 0 + for key, c in config.env.exp_per_ep.items(): + self.exp_idx_mapping[key] = acc_c + acc_c += c + + self.workers = [ + Process( + target=single_player_worker, + args=(i, config, self.exp_idx_mapping, self.pipes[i][1], self.action_io, self.exp_output) + ) for i in range(self.parallel_cnt) + ] + for w in self.workers: + w.start() + + self._logger.info("all thread started") + + self._roll_out_time = 0 + self._trainsfer_time = 0 + self._roll_out_cnt = 0 + + def roll_out(self): + """Rollout using current policy in the AgentManager. + + Returns: + result (dict): The key is the agent code, the value is the experience list stored in numpy.array. + """ + # Compute the time used for state preparation in the child process. + t_state = 0 + # Compute the time used for action inference. + t_action = 0 + + for p in self.pipes: + p[0].send("reset") + self._roll_out_cnt += 1 + + step_i = 0 + tick = time.time() + while True: + signals = [p[0].recv() for p in self.pipes] + if signals[0] == "done": + break + + step_i += 1 + + t = time.time() + graph = gnn_union( + self.action_io_np["p"], self.action_io_np["po"], self.action_io_np["pedge"], + self.action_io_np["v"], self.action_io_np["vo"], self.action_io_np["vedge"], + self._gnn_state_shaper.p2p_static_graph, self.action_io_np["ppedge"], + self.action_io_np["mask"], self.device + ) + t_state += time.time() - t + + assert(np.min(self.action_io_np["pid"]) == np.max(self.action_io_np["pid"])) + assert(np.min(self.action_io_np["vid"]) == np.max(self.action_io_np["vid"])) + + t = time.time() + actions = self._inference_agents.choose_action( + agent_id=(self.action_io_np["pid"][0], self.action_io_np["vid"][0]), state=graph + ) + t_action += time.time() - t + + for i, p in enumerate(self.pipes): + p[0].send(actions[i]) + + self._roll_out_time += time.time() - tick + tick = time.time() + self._logger.info("receiving exp") + logs = [p[0].recv() for p in self.pipes] + + self._logger.info(f"Mean of shortage: {np.mean(self.action_io_np['sh'])}") + self._trainsfer_time += time.time() - tick + + self._logger.debug(dict(zip(self.log_header, self.action_io_np["sh"]))) + + with open(os.path.join(self.config.log.path, f"logs_{self._roll_out_cnt}"), "wb") as fp: + pickle.dump(logs, fp) + + self._logger.info("organize exp_dict") + result = organize_exp_list(self.exp_output_np, self.exp_idx_mapping) + + if self.config.log.exp.enable and self._roll_out_cnt % self.config.log.exp.freq == 0: + with open(os.path.join(self.config.log.path, f"exp_{self._roll_out_cnt}"), "wb") as fp: + pickle.dump(result, fp) + + self._logger.debug(f"play time: {int(self._roll_out_time)}") + self._logger.debug(f"transfer time: {int(self._trainsfer_time)}") + return result + + def exit(self): + """Terminate the child processes.""" + for p in self.pipes: + p[0].send("close") diff --git a/examples/cim/gnn/actor_critic.py b/examples/cim/gnn/actor_critic.py new file mode 100644 index 000000000..53c10be82 --- /dev/null +++ b/examples/cim/gnn/actor_critic.py @@ -0,0 +1,179 @@ +import os + +import torch +from torch import nn +from torch.distributions import Categorical +from torch.nn.utils import clip_grad + +from examples.cim.gnn.utils import gnn_union +from maro.rl import AbsAlgorithm + + +class ActorCritic(AbsAlgorithm): + """Actor-Critic algorithm in CIM problem. + + The vanilla ac algorithm. + + Args: + model (nn.Module): A actor-critic module outputing both the policy network and the value network. + device (torch.device): A PyTorch device instance where the module is computed on. + p2p_adj (numpy.array): The static port-to-port adjencency matrix. + td_steps (int): The value "n" in the n-step TD algorithm. + gamma (float): The time decay. + learning_rate (float): The learning rate for the module. + entropy_factor (float): The weight of the policy"s entropy to boost exploration. + """ + + def __init__( + self, model: nn.Module, device: torch.device, p2p_adj=None, td_steps=100, gamma=0.97, learning_rate=0.0003, + entropy_factor=0.1): + self._gamma = gamma + self._td_steps = td_steps + self._value_discount = gamma**100 + self._entropy_factor = entropy_factor + self._device = device + self._tot_batchs = 0 + self._p2p_adj = p2p_adj + super().__init__( + model_dict={"a&c": model}, optimizer_opt={"a&c": (torch.optim.Adam, {"lr": learning_rate})}, + loss_func_dict={}, hyper_params=None) + + def choose_action(self, state: dict, p_idx: int, v_idx: int): + """Get action from the AC model. + + Args: + state (dict): A dictionary containing the input to the module. For example: + { + "v": v, + "p": p, + "pe": { + "edge": pedge, + "adj": padj, + "mask": pmask, + }, + "ve": { + "edge": vedge, + "adj": vadj, + "mask": vmask, + }, + "ppe": { + "edge": ppedge, + "adj": p2p_adj, + "mask": p2p_mask, + }, + "mask": seq_mask, + } + p_idx (int): The identity of the port doing the action. + v_idx (int): The identity of the vessel doing the action. + + Returns: + model_action (numpy.int64): The action returned from the module. + """ + with torch.no_grad(): + prob, _ = self._model_dict["a&c"](state, a=True, p_idx=p_idx, v_idx=v_idx) + distribution = Categorical(prob) + model_action = distribution.sample().cpu().numpy() + return model_action + + def train(self, batch, p_idx, v_idx): + """Model training. + + Args: + batch (dict): The dictionary of a batch of experience. For example: + { + "s": the dictionary of state, + "a": model actions in numpy array, + "R": the n-step accumulated reward, + "s"": the dictionary of the next state, + } + p_idx (int): The identity of the port doing the action. + v_idx (int): The identity of the vessel doing the action. + + Returns: + a_loss (float): action loss. + c_loss (float): critic loss. + e_loss (float): entropy loss. + tot_norm (float): the L2 norm of the gradient. + + """ + self._tot_batchs += 1 + item_a_loss, item_c_loss, item_e_loss = 0, 0, 0 + obs_batch = batch["s"] + action_batch = batch["a"] + return_batch = batch["R"] + next_obs_batch = batch["s_"] + + obs_batch = gnn_union( + obs_batch["p"], obs_batch["po"], obs_batch["pedge"], obs_batch["v"], obs_batch["vo"], obs_batch["vedge"], + self._p2p_adj, obs_batch["ppedge"], obs_batch["mask"], self._device) + action_batch = torch.from_numpy(action_batch).long().to(self._device) + return_batch = torch.from_numpy(return_batch).float().to(self._device) + next_obs_batch = gnn_union( + next_obs_batch["p"], next_obs_batch["po"], next_obs_batch["pedge"], next_obs_batch["v"], + next_obs_batch["vo"], next_obs_batch["vedge"], self._p2p_adj, next_obs_batch["ppedge"], + next_obs_batch["mask"], self._device) + + # Train actor network. + self._optimizer["a&c"].zero_grad() + + # Every port has a value. + # values.shape: (batch, p_cnt) + probs, values = self._model_dict["a&c"](obs_batch, a=True, p_idx=p_idx, v_idx=v_idx, c=True) + distribution = Categorical(probs) + log_prob = distribution.log_prob(action_batch) + entropy_loss = distribution.entropy() + + _, values_ = self._model_dict["a&c"](next_obs_batch, c=True) + advantage = return_batch + self._value_discount * values_.detach() - values + + if self._entropy_factor != 0: + # actor_loss = actor_loss* torch.log(entropy_loss + np.e) + advantage[:, p_idx] += self._entropy_factor * entropy_loss.detach() + + actor_loss = - (log_prob * torch.sum(advantage, axis=-1).detach()).mean() + + item_a_loss = actor_loss.item() + item_e_loss = entropy_loss.mean().item() + + # Train critic network. + critic_loss = torch.sum(advantage.pow(2), axis=1).mean() + item_c_loss = critic_loss.item() + # torch.nn.utils.clip_grad_norm_(self._critic_model.parameters(),0.5) + tot_loss = 0.1 * actor_loss + critic_loss + tot_loss.backward() + tot_norm = clip_grad.clip_grad_norm_(self._model_dict["a&c"].parameters(), 1) + self._optimizer["a&c"].step() + return item_a_loss, item_c_loss, item_e_loss, float(tot_norm) + + def set_weights(self, weights): + self._model_dict["a&c"].load_state_dict(weights) + + def get_weights(self): + return self._model_dict["a&c"].state_dict() + + def _get_save_idx(self, fp_str): + return int(fp_str.split(".")[0].split("_")[0]) + + def save_model(self, pth, id): + if not os.path.exists(pth): + os.makedirs(pth) + pth = os.path.join(pth, f"{id}_ac.pkl") + torch.save(self._model_dict["a&c"].state_dict(), pth) + + def _set_gnn_weights(self, weights): + for key in weights: + if key in self._model_dict["a&c"].state_dict().keys(): + self._model_dict["a&c"].state_dict()[key].copy_(weights[key]) + + def load_model(self, folder_pth, idx=-1): + if idx == -1: + fps = os.listdir(folder_pth) + fps = [f for f in fps if "ac" in f] + fps.sort(key=self._get_save_idx) + ac_pth = fps[-1] + else: + ac_pth = f"{idx}_ac.pkl" + pth = os.path.join(folder_pth, ac_pth) + with open(pth, "rb") as fp: + weights = torch.load(fp, map_location=self._device) + self._set_gnn_weights(weights) diff --git a/examples/cim/gnn/agent.py b/examples/cim/gnn/agent.py new file mode 100644 index 000000000..a99171e72 --- /dev/null +++ b/examples/cim/gnn/agent.py @@ -0,0 +1,40 @@ +from collections import defaultdict + +import numpy as np + +from examples.cim.gnn.numpy_store import Shuffler +from maro.rl import AbsAgent +from maro.utils import DummyLogger + + +class TrainableAgent(AbsAgent): + def __init__(self, name, algorithm, experience_pool, logger=DummyLogger()): + self._logger = logger + super().__init__(name, algorithm, experience_pool) + + def train(self, training_config): + loss_dict = defaultdict(list) + for j in range(training_config.shuffle_time): + shuffler = Shuffler(self._experience_pool, batch_size=training_config.batch_size) + while shuffler.has_next(): + batch = shuffler.next() + actor_loss, critic_loss, entropy_loss, tot_loss = self._algorithm.train( + batch, self._name[0], self._name[1]) + loss_dict["actor"].append(actor_loss) + loss_dict["critic"].append(critic_loss) + loss_dict["entropy"].append(entropy_loss) + loss_dict["tot"].append(tot_loss) + + a_loss = np.mean(loss_dict["actor"]) + c_loss = np.mean(loss_dict["critic"]) + e_loss = np.mean(loss_dict["entropy"]) + tot_loss = np.mean(loss_dict["tot"]) + self._logger.debug( + f"code: {str(self._name)} \t actor: {float(a_loss)} \t critic: {float(c_loss)} \t entropy: {float(e_loss)} \ + \t tot: {float(tot_loss)}") + + self._experience_pool.clear() + return loss_dict + + def choose_action(self, model_state): + return self._algorithm.choose_action(model_state, self._name[0], self._name[1]) diff --git a/examples/cim/gnn/agent_manager.py b/examples/cim/gnn/agent_manager.py new file mode 100644 index 000000000..e51dc704d --- /dev/null +++ b/examples/cim/gnn/agent_manager.py @@ -0,0 +1,118 @@ +from copy import copy + +import numpy as np +import torch + +from examples.cim.gnn.agent import TrainableAgent +from examples.cim.gnn.actor_critic import ActorCritic +from examples.cim.gnn.simple_gnn import SharedAC +from examples.cim.gnn.numpy_store import NumpyStore +from examples.cim.gnn.state_shaper import GNNStateShaper +from maro.rl import AbsAgentManager, AgentMode +from maro.utils import DummyLogger + + +class SimpleAgentManger(AbsAgentManager): + def __init__( + self, name, agent_id_list, port_code_list, vessel_code_list, demo_env, state_shaper: GNNStateShaper, + logger=DummyLogger()): + super().__init__( + name, AgentMode.TRAIN, agent_id_list, state_shaper=state_shaper, action_shaper=None, + experience_shaper=None, explorer=None) + self.port_code_list = copy(port_code_list) + self.vessel_code_list = copy(vessel_code_list) + self.demo_env = demo_env + self._logger = logger + + def assemble(self, config): + v_dim, vedge_dim = self._state_shaper.get_input_dim("v"), self._state_shaper.get_input_dim("vedge") + p_dim, pedge_dim = self._state_shaper.get_input_dim("p"), self._state_shaper.get_input_dim("pedge") + + self.device = torch.device(config.training.device) + self._logger.info(config.training.device) + ac_model = SharedAC( + p_dim, pedge_dim, v_dim, vedge_dim, config.model.tick_buffer, config.model.action_dim).to(self.device) + + value_dict = { + ("s", "v"): + ( + (config.model.tick_buffer, len(self.vessel_code_list), self._state_shaper.get_input_dim("v")), + np.float32, False), + ("s", "p"): + ( + (config.model.tick_buffer, len(self.port_code_list), self._state_shaper.get_input_dim("p")), + np.float32, False), + ("s", "vo"): ((len(self.vessel_code_list), len(self.port_code_list)), np.int64, True), + ("s", "po"): ((len(self.port_code_list), len(self.vessel_code_list)), np.int64, True), + ("s", "vedge"): + ( + (len(self.vessel_code_list), len(self.port_code_list), self._state_shaper.get_input_dim("vedge")), + np.float32, True), + ("s", "pedge"): + ( + (len(self.port_code_list), len(self.vessel_code_list), self._state_shaper.get_input_dim("vedge")), + np.float32, True), + ("s", "ppedge"): + ( + (len(self.port_code_list), len(self.port_code_list), self._state_shaper.get_input_dim("pedge")), + np.float32, True), + ("s", "mask"): ((config.model.tick_buffer, ), np.bool, True), + + ("s_", "v"): + ( + (config.model.tick_buffer, len(self.vessel_code_list), self._state_shaper.get_input_dim("v")), + np.float32, False), + ("s_", "p"): + ( + (config.model.tick_buffer, len(self.port_code_list), self._state_shaper.get_input_dim("p")), + np.float32, False), + ("s_", "vo"): ((len(self.vessel_code_list), len(self.port_code_list)), np.int64, True), + ("s_", "po"): + ( + (len(self.port_code_list), len(self.vessel_code_list)), np.int64, True), + ("s_", "vedge"): + ( + (len(self.vessel_code_list), len(self.port_code_list), self._state_shaper.get_input_dim("vedge")), + np.float32, True), + ("s_", "pedge"): + ( + (len(self.port_code_list), len(self.vessel_code_list), self._state_shaper.get_input_dim("vedge")), + np.float32, True), + ("s_", "ppedge"): + ( + (len(self.port_code_list), len(self.port_code_list), self._state_shaper.get_input_dim("pedge")), + np.float32, True), + ("s_", "mask"): ((config.model.tick_buffer, ), np.bool, True), + + # To identify one dimension variable. + ("R",): ((len(self.port_code_list), ), np.float32, True), + ("a",): (tuple(), np.int64, True), + } + + self._algorithm = ActorCritic( + ac_model, self.device, td_steps=config.training.td_steps, p2p_adj=self._state_shaper.p2p_static_graph, + gamma=config.training.gamma, learning_rate=config.training.learning_rate) + + for agent_id, cnt in config.env.exp_per_ep.items(): + experience_pool = NumpyStore(value_dict, config.training.parallel_cnt * config.training.train_freq * cnt) + self._agent_dict[agent_id] = TrainableAgent(agent_id, self._algorithm, experience_pool, self._logger) + + def choose_action(self, agent_id, state): + return self._agent_dict[agent_id].choose_action(state) + + def load_models_from_files(self, model_pth): + self._algorithm.load_model(model_pth) + + def train(self, training_config): + for agent in self._agent_dict.values(): + agent.train(training_config) + + def store_experiences(self, experiences): + for code, exp_list in experiences.items(): + self._agent_dict[code].store_experiences(exp_list) + + def save_model(self, pth, id): + self._algorithm.save_model(pth, id) + + def load_model(self, pth): + self._algorithm.load_model(pth) diff --git a/examples/cim/gnn/config.yml b/examples/cim/gnn/config.yml new file mode 100644 index 000000000..55a249019 --- /dev/null +++ b/examples/cim/gnn/config.yml @@ -0,0 +1,36 @@ +env: + seed: 10 + param: + durations: 1120 + scenario: "cim" + topology: "global_trade.22p_l0.6" + # topology: "toy.4p_ssdd_l0.0" +training: + enable: True + parallel_cnt: 24 + device: "cuda:0" + batch_size: 150 + shuffle_time: 1 + rollout_cnt: 500 + train_freq: 1 + model_save_freq: 1 + gamma: 0.99 + learning_rate: 0.00005 + td_steps: 100 + entropy_loss_enable: True +model: + path: "./" + tick_buffer: 20 + hidden_size: 32 + graph_output_dim: 32 + action_dim: 21 + feature: + # temporal or random, if temporal, the edges in the graph are listed in the order of event time, else in a + # random order. + attention_order: temporal + onehot_identity: False +log: + path: "./" + exp: + enable: false + freq: 10 diff --git a/examples/cim/gnn/experience_shaper.py b/examples/cim/gnn/experience_shaper.py new file mode 100644 index 000000000..1bc6f6352 --- /dev/null +++ b/examples/cim/gnn/experience_shaper.py @@ -0,0 +1,111 @@ +from collections import defaultdict + +import numpy as np + + +class ExperienceShaper: + def __init__( + self, static_list, dynamic_list, max_tick, gnn_state_shaper, scale_factor=0.0001, time_slot=100, + discount_factor=0.97, idx=-1, shared_storage=None, exp_idx_mapping=None): + self._static_list = list(static_list) + self._dynamic_list = list(dynamic_list) + self._time_slot = time_slot + self._discount_factor = discount_factor + self._discount_vector = np.logspace(1, self._time_slot, self._time_slot, base=discount_factor) + self._max_tick = max_tick + self._tick_range = list(range(self._max_tick)) + self._len_return = self._max_tick - self._time_slot + self._gnn_state_shaper = gnn_state_shaper + self._fulfillment_list, self._shortage_list, self._experience_dict = None, None, None + self._experience_dict = defaultdict(list) + self._init_state() + self._idx = idx + self._exp_idx_mapping = exp_idx_mapping + self._shared_storage = shared_storage + self._scale_factor = scale_factor + + def _init_state(self): + self._fulfillment_list, self._shortage_list = np.zeros(self._max_tick + 1), np.zeros(self._max_tick + 1) + self._experience_dict = defaultdict(list) + self._last_tick = 0 + + def record(self, decision_event, model_action, model_input): + # Only the experience that has the next state of given time slot is valuable. + if decision_event.tick + self._time_slot < self._max_tick: + self._experience_dict[decision_event.port_idx, decision_event.vessel_idx].append({ + "tick": decision_event.tick, + "s": model_input, + "a": model_action, + }) + + def _compute_delta(self, arr): + delta = np.array(arr) + delta[1:] -= arr[:-1] + return delta + + def _batch_obs_to_numpy(self, obs): + v = np.stack([o["v"] for o in obs], axis=0) + p = np.stack([o["p"] for o in obs], axis=0) + vo = np.stack([o["vo"] for o in obs], axis=0) + po = np.stack([o["po"] for o in obs], axis=0) + return {"p": p, "v": v, "vo": vo, "po": po} + + def __call__(self, snapshot_list): + if self._shared_storage is None: + return + + shortage = snapshot_list["ports"][self._tick_range:self._static_list:"shortage"].reshape(self._max_tick, -1) + fulfillment = snapshot_list["ports"][self._tick_range:self._static_list:"fulfillment"] \ + .reshape(self._max_tick, -1) + delta = fulfillment - shortage + R = np.empty((self._len_return, len(self._static_list)), dtype=np.float) + for i in range(0, self._len_return, 1): + R[i] = np.dot(self._discount_vector, delta[i + 1: i + self._time_slot + 1]) + + for (agent_idx, vessel_idx), exp_list in self._experience_dict.items(): + for exp in exp_list: + tick = exp["tick"] + exp["s_"] = self._gnn_state_shaper(tick=tick + self._time_slot) + exp["R"] = self._scale_factor * R[tick] + + tmpi = 0 + for (agent_idx, vessel_idx), idx_base in self._exp_idx_mapping.items(): + exp_list = self._experience_dict[(agent_idx, vessel_idx)] + exp_len = len(exp_list) + # Here, we assume that exp_idx_mapping order is not changed. + self._shared_storage["len"][self._idx, tmpi] = exp_len + self._shared_storage["s"]["v"][:, idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s"]["v"] for e in exp_list], axis=1) + self._shared_storage["s"]["p"][:, idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s"]["p"] for e in exp_list], axis=1) + self._shared_storage["s"]["vo"][idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s"]["vo"] for e in exp_list], axis=0) + self._shared_storage["s"]["po"][idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s"]["po"] for e in exp_list], axis=0) + self._shared_storage["s"]["vedge"][idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s"]["vedge"] for e in exp_list], axis=0) + self._shared_storage["s"]["pedge"][idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s"]["pedge"] for e in exp_list], axis=0) + + self._shared_storage["s_"]["v"][:, idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s_"]["v"] for e in exp_list], axis=1) + self._shared_storage["s_"]["p"][:, idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s_"]["p"] for e in exp_list], axis=1) + self._shared_storage["s_"]["vo"][idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s_"]["vo"] for e in exp_list], axis=0) + self._shared_storage["s_"]["po"][idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s_"]["po"] for e in exp_list], axis=0) + self._shared_storage["s_"]["vedge"][idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s_"]["vedge"] for e in exp_list], axis=0) + self._shared_storage["s_"]["pedge"][idx_base:idx_base + exp_len, self._idx] = \ + np.stack([e["s_"]["pedge"] for e in exp_list], axis=0) + + self._shared_storage["a"][idx_base: idx_base + exp_len, self._idx] = \ + np.array([exp["a"] for exp in exp_list], dtype=np.int64) + self._shared_storage["R"][idx_base: idx_base + exp_len, self._idx] = \ + np.vstack([exp["R"] for exp in exp_list]) + tmpi += 1 + + def reset(self): + del self._experience_dict + self._init_state() diff --git a/examples/cim/gnn/launcher.py b/examples/cim/gnn/launcher.py new file mode 100644 index 000000000..6714e976f --- /dev/null +++ b/examples/cim/gnn/launcher.py @@ -0,0 +1,69 @@ +import os +import datetime + +from examples.cim.gnn.actor import ParallelActor +from examples.cim.gnn.learner import GNNLearner +from examples.cim.gnn.state_shaper import GNNStateShaper +from examples.cim.gnn.utils import decision_cnt_analysis, load_config, save_config, save_code, return_scaler +from examples.cim.gnn.agent_manager import SimpleAgentManger +from maro.simulator import Env +from maro.utils import Logger + + +if __name__ == "__main__": + config_pth = "examples/cim/gnn/config.yml" + config = load_config(config_pth) + + # Generate log path. + date_str = datetime.datetime.now().strftime("%Y%m%d") + time_str = datetime.datetime.now().strftime("%H%M%S.%f") + subfolder_name = f"{config.env.param.topology}_{time_str}" + + # Log path. + config.log.path = os.path.join(config.log.path, date_str, subfolder_name) + if not os.path.exists(config.log.path): + os.makedirs(config.log.path) + + simulation_logger = Logger(tag="simulation", dump_folder=config.log.path, dump_mode="w", auto_timestamp=False) + + # Create a demo environment to retrieve environment information. + simulation_logger.info("Approximating the experience quantity of each agent...") + demo_env = Env(**config.env.param) + config.env.exp_per_ep = decision_cnt_analysis(demo_env, pv=True, buffer_size=8) + simulation_logger.info(config.env.exp_per_ep) + + # Add some buffer to prevent overlapping. + config.env.return_scaler, tot_order_amount = return_scaler( + demo_env, tick=config.env.param.durations, gamma=config.training.gamma) + simulation_logger.info(f"Return value will be scaled down by the factor {config.env.return_scaler}") + + save_config(config, os.path.join(config.log.path, "config.yml")) + save_code("examples/cim/gnn", config.log.path) + + port_mapping = demo_env.summary["node_mapping"]["ports"] + vessel_mapping = demo_env.summary["node_mapping"]["vessels"] + + # Create a mock gnn_state_shaper. + static_code_list, dynamic_code_list = list(port_mapping.values()), list(vessel_mapping.values()) + gnn_state_shaper = GNNStateShaper( + static_code_list, dynamic_code_list, config.env.param.durations, config.model.feature, + tick_buffer=config.model.tick_buffer, only_demo=True, max_value=demo_env.configs["total_containers"]) + gnn_state_shaper.compute_static_graph_structure(demo_env) + + # Create and assemble agent_manager. + agent_id_list = list(config.env.exp_per_ep.keys()) + training_logger = Logger(tag="training", dump_folder=config.log.path, dump_mode="w", auto_timestamp=False) + agent_manager = SimpleAgentManger( + "CIM-GNN-manager", agent_id_list, static_code_list, dynamic_code_list, demo_env, gnn_state_shaper, + training_logger) + agent_manager.assemble(config) + + # Create the rollout actor to collect experience. + actor = ParallelActor(config, demo_env, gnn_state_shaper, agent_manager, logger=simulation_logger) + + # Learner function for training and testing. + learner = GNNLearner(actor, agent_manager, logger=simulation_logger) + learner.train(config.training) + + # Cancel all the child process used for rollout. + actor.exit() diff --git a/examples/cim/gnn/learner.py b/examples/cim/gnn/learner.py new file mode 100644 index 000000000..2955d4ce1 --- /dev/null +++ b/examples/cim/gnn/learner.py @@ -0,0 +1,51 @@ +import time +import os + +from examples.cim.gnn.agent_manager import SimpleAgentManger +from examples.cim.gnn.actor import ParallelActor +from maro.rl import AbsLearner +from maro.utils import DummyLogger + + +class GNNLearner(AbsLearner): + """Learner class for the training pipeline and the specialized logging in GNN solution for CIM problem. + + Args: + actor (AbsActor): The actor instance to collect experience. + trainable_agents (AbsAgentManager): The agent manager for training RL models. + logger (Logger): The logger to save/print the message. + """ + + def __init__(self, actor: ParallelActor, trainable_agents: SimpleAgentManger, logger=DummyLogger()): + super().__init__() + self._actor = actor + self._trainable_agents = trainable_agents + self._logger = logger + + def train(self, training_config, log_pth=None): + rollout_time = 0 + training_time = 0 + for i in range(training_config.rollout_cnt): + self._logger.info(f"rollout {i + 1}") + tick = time.time() + exp_dict = self._actor.roll_out() + + rollout_time += time.time() - tick + + self._logger.info("start putting exps") + self._trainable_agents.store_experiences(exp_dict) + + if training_config.enable and i % training_config.train_freq == training_config.train_freq - 1: + self._logger.info("training start") + tick = time.time() + self._trainable_agents.train(training_config) + training_time += time.time() - tick + + if log_pth is not None and (i + 1) % training_config.model_save_freq == 0: + self._trainable_agents.save_model(os.path.join(log_pth, "models"), i + 1) + + self._logger.debug(f"total rollout_time: {int(rollout_time)}") + self._logger.debug(f"train_time: {int(training_time)}") + + def test(self): + pass diff --git a/examples/cim/gnn/numpy_store.py b/examples/cim/gnn/numpy_store.py new file mode 100644 index 000000000..c8d8e6d3b --- /dev/null +++ b/examples/cim/gnn/numpy_store.py @@ -0,0 +1,185 @@ +import numpy as np +from typing import Sequence + +from maro.rl import AbsStore + + +def get_item(data_dict, key_tuple): + """Helper function to get the value in a hierarchical dictionary given the key path. + + Args: + data_dict (dict): The data structure. For example: + { + "a": { + "b": 1, + "c": { + "d": 2, + } + } + } + + key_tuple (tuple): The key path to the target field. For example, given the data_dict above, the key_tuple + ("a", "c", "d") should return 2. + """ + for key in key_tuple: + data_dict = data_dict[key] + return data_dict + + +def set_item(data_dict, key_tuple, data): + """The setter function corresponding to the get_item function.""" + for i, key in enumerate(key_tuple): + if key not in data_dict: + data_dict[key] = {} + if i == len(key_tuple) - 1: + data_dict[key] = data + else: + data_dict = data_dict[key] + + +class NumpyStore(AbsStore): + def __init__(self, domain_type_dict, capacity): + """ + Args: + domain_type_dict (dict): The dictionary describing the name, structure and type of each field in the + experience. Each field in the experience is the key-value pair in the folowing structure: + (field_name): (size_of_an_instance, data_type, batch_first) + + For example: + ("s"): ((32, 64), np.float32, True) + + The field can be a hierarchical dictionary by identifying the full path to the root. + + For example: + { + ("s", "p"): ((32, 64), np.float32, True) + ("s", "v"): ((48, ), np.float32, False), + } + Then the batch of experience returned by self.get(indexes) is: + { + "s": + { + "p": numpy.array with size (batch, 32, 64), + "v": numpy.array with size (32, batch, 48), + } + } + Note that for the field ("s", "v"), the batch is in the 2nd dimension because the batch_first attribute + is False. + + capacity (int): The maximum stored experience in the store. + """ + super().__init__() + self.domain_type_dict = dict(domain_type_dict) + self.store = { + key: np.zeros( + shape=(capacity, *shape) if batch_first else (shape[0], capacity, *shape[1:]), dtype=data_type) + for key, (shape, data_type, batch_first) in domain_type_dict.items()} + self.batch_first_store = {key: batch_first for key, (_, _, batch_first) in domain_type_dict.items()} + + self.cnt = 0 + self.capacity = capacity + + def put(self, exp_dict: dict): + """Insert a batch of experience into the store. + + If the store reaches the maximum capacity, this function will replace the experience in the store randomly. + + Args: + exp_dict (dict): The dictionary of a batch of experience. For example: + + { + "s": + { + "p": numpy.array with size (batch, 32, 64), + "v": numpy.array with size (32, batch, 48), + } + } + + The structure should be consistent with the structure defined in the __init__ function. + + Returns: + indexes (numpy.array): The list of the indexes each experience in the batch is located in. + """ + dlen = exp_dict["len"] + append_end = min(max(self.capacity - self.cnt, 0), dlen) + idxs = np.zeros(dlen, dtype=np.int) + if append_end != 0: + for key in self.domain_type_dict.keys(): + data = get_item(exp_dict, key) + if self.batch_first_store[key]: + self.store[key][self.cnt: self.cnt + append_end] = data[0: append_end] + else: + self.store[key][:, self.cnt: self.cnt + append_end] = data[:, 0: append_end] + idxs[: append_end] = np.arange(self.cnt, self.cnt + append_end) + if append_end < dlen: + replace_idx = self._get_replace_idx(dlen - append_end) + for key in self.domain_type_dict.keys(): + data = get_item(exp_dict, key) + if self.batch_first_store[key]: + self.store[key][replace_idx] = data[append_end: dlen] + else: + self.store[key][:, replace_idx] = data[:, append_end: dlen] + idxs[append_end: dlen] = replace_idx + self.cnt += dlen + return idxs + + def _get_replace_idx(self, cnt): + return np.random.randint(low=0, high=self.capacity, size=cnt) + + def get(self, indexes: np.array): + """Get the experience indexed in the indexes list from the store. + + Args: + indexes (np.array): A numpy array containing the indexes of a batch experience. + + Returns: + data_dict (dict): the structure same as that defined in the __init__ function. + """ + data_dict = {} + for key in self.domain_type_dict.keys(): + if self.batch_first_store[key]: + set_item(data_dict, key, self.store[key][indexes]) + else: + set_item(data_dict, key, self.store[key][:, indexes]) + return data_dict + + def __len__(self): + return min(self.capacity, self.cnt) + + def update(self, indexes: Sequence, contents: Sequence): + raise NotImplementedError("NumpyStore does not support modifying the experience!") + + def sample(self, size, weights: Sequence, replace: bool = True): + raise NotImplementedError("NumpyStore does not support sampling. Please use outer sampler to fetch samples!") + + def clear(self): + """Remove all the experience in the store.""" + self.cnt = 0 + + +class Shuffler: + def __init__(self, store: NumpyStore, batch_size: int): + """The helper class for fast batch sampling. + + Args: + store (NumpyStore): The data source for sampling. + batch_size (int): The size of a batch. + """ + self._store = store + self._shuffled_seq = np.arange(0, len(store)) + np.random.shuffle(self._shuffled_seq) + self._start = 0 + self._batch_size = batch_size + + def next(self): + """Uniformly sampling out a batch in the store.""" + if self._start >= len(self._store): + return None + end = min(self._start + self._batch_size, len(self._store)) + rst = self._store.get(self._shuffled_seq[self._start: end]) + self._start += self._batch_size + return rst + + def has_next(self): + """Check if any experience is not visited.""" + return self._start < len(self._store) diff --git a/examples/cim/gnn/shared_structure.py b/examples/cim/gnn/shared_structure.py new file mode 100644 index 000000000..b79a44370 --- /dev/null +++ b/examples/cim/gnn/shared_structure.py @@ -0,0 +1,46 @@ +import multiprocessing + +import numpy as np + + +def init_shared_memory(data_structure): + """Initialize the data structure of the shared memory. + + Args: + data_structure: The dictionary that describes the data structure. For example, + { + "a": (shape, type), + "b": { + "b1": (shape, type), + } + } + """ + if isinstance(data_structure, tuple): + mult = 1 + for i in data_structure[0]: + mult *= i + return multiprocessing.Array(data_structure[1], mult, lock=False) + else: + shared_data = {} + for k, v in data_structure.items(): + shared_data[k] = init_shared_memory(v) + return shared_data + + +def shared_data2numpy(shared_data, structure_info): + if not isinstance(shared_data, dict): + return np.frombuffer(shared_data, dtype=structure_info[1]).reshape(structure_info[0]) + else: + numpy_dict = {} + for k, v in shared_data.items(): + numpy_dict[k] = shared_data2numpy(v, structure_info[k]) + return numpy_dict + + +class SharedStructure: + def __init__(self, data_structure): + self.data_structure = data_structure + self.shared = init_shared_memory(data_structure) + + def structuralize(self): + return shared_data2numpy(self.shared, self.data_structure) diff --git a/examples/cim/gnn/simple_gnn.py b/examples/cim/gnn/simple_gnn.py new file mode 100644 index 000000000..61d5eefb6 --- /dev/null +++ b/examples/cim/gnn/simple_gnn.py @@ -0,0 +1,335 @@ +import math + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +from torch.nn import TransformerEncoder, TransformerEncoderLayer +from torch.nn.modules.activation import MultiheadAttention +from torch.nn.modules.dropout import Dropout +from torch.nn.modules.normalization import LayerNorm + + +class PositionalEncoder(nn.Module): + """ + The positional encoding used in transformer to get the sequential information. + + The code is based on the PyTorch version in web + https://pytorch.org/tutorials/beginner/transformer_tutorial.html?highlight=positionalencoding + """ + + def __init__(self, d_model, max_seq_len=80): + super().__init__() + self.d_model = d_model + self.times = 4 * math.sqrt(self.d_model) + + # Create constant "pe" matrix with values dependant on pos and i. + self.pe = torch.zeros(max_seq_len, d_model) + for pos in range(max_seq_len): + for i in range(0, d_model, 2): + self.pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model))) + self.pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model))) + + self.pe = self.pe.unsqueeze(1) / self.d_model + + def forward(self, x): + # Make embeddings relatively larger. + addon = self.pe[: x.shape[0], :, : x.shape[2]].to(x.get_device()) + return x + addon + + +class SimpleGATLayer(nn.Module): + """The enhanced graph attention layer for heterogenenous neighborhood. + + It first utilizes pre-layers for both the source and destination node to map their features into the same hidden + size. If the edge also has features, they are concatenated with those of the corresponding source node before being + fed to the pre-layers. Then the graph attention(https://arxiv.org/abs/1710.10903) is done to aggregate information + from the source nodes to the destination nodes. The residual connection and layer normalization are also used to + enhance the performance, which is similar to the Transformer(https://arxiv.org/abs/1706.03762). + + Args: + src_dim (int): The feature dimension of the source nodes. + dest_dim (int): The feature dimension of the destination nodes. + edge_dim (int): The feature dimension of the edges. If the edges have no feature, it should be set 0. + hidden_size (int): The hidden size both the destination and source is mapped into. + nhead (int): The number of head in the multi-head attention. + position_encoding (bool): the neighbor source nodes is aggregated in order(True) or orderless(False). + """ + + def __init__(self, src_dim, dest_dim, edge_dim, hidden_size, nhead=4, position_encoding=True): + super().__init__() + self.src_dim = src_dim + self.dest_dim = dest_dim + self.edge_dim = edge_dim + self.hidden_size = hidden_size + self.nhead = nhead + src_layers = [] + src_layers.append(nn.Linear(src_dim + edge_dim, hidden_size)) + src_layers.append(GeLU()) + self.src_pre_layer = nn.Sequential(*src_layers) + + dest_layers = [] + dest_layers.append(nn.Linear(dest_dim, hidden_size)) + dest_layers.append(GeLU()) + self.dest_pre_layer = nn.Sequential(*dest_layers) + + self.att = MultiheadAttention(embed_dim=hidden_size, num_heads=nhead) + self.att_dropout = Dropout(0.1) + self.att_norm = LayerNorm(hidden_size) + + self.zero_padding_template = torch.zeros((1, src_dim), dtype=torch.float) + + def forward(self, src: Tensor, dest: Tensor, adj: Tensor, mask: Tensor, edges: Tensor = None): + """Information aggregation from the source nodes to the destination nodes. + + Args: + src (Tensor): The source nodes in a batch of graph. + dest (Tensor): The destination nodes in a batch of graph. + adj (Tensor): The adjencency list stored in a 2D matrix in the batch-second format. The first dimension is + the maximum amount of the neighbors the destinations have. As the neighbor quantities vary from one + destination to another, the short sequences are padded with 0. + mask (Tensor): The mask identifies if a position in the adj is padded. Note that it is stored in the + batch-first format. + + Returns: + destination_emb: The embedding of the destinations after the GAT layer. + + Shape: + src: (batch, src_cnt, src_dim) + dest: (batch, dest_cnt, dest_dim) + adj: (src_neighbor_cnt, batch*dest_cnt) + mask: (batch*dest_cnt)*src_neighbor_cnt + edges: (batch*dest_cnt, src_neighbor_cnt, edge_dim) + destination_emb: (batch, dest_cnt, hidden_size) + + """ + assert(self.src_dim == src.shape[-1]) + assert(self.dest_dim == dest.shape[-1]) + batch, s_cnt, src_dim = src.shape + batch, d_cnt, dest_dim = dest.shape + src_neighbor_cnt = adj.shape[0] + + src_embedding = src.reshape(-1, src_dim) + src_embedding = torch.cat((self.zero_padding_template.to(src_embedding.get_device()), src_embedding)) + + flat_adj = adj.reshape(-1) + src_embedding = src_embedding[flat_adj].reshape(src_neighbor_cnt, -1, src_dim) + if edges is not None: + src_embedding = torch.cat((src_embedding, edges), axis=2) + + src_input = self.src_pre_layer( + src_embedding.reshape(-1, src_dim + self.edge_dim)). \ + reshape(*src_embedding.shape[:2], self.hidden_size) + dest_input = self.dest_pre_layer(dest.reshape(-1, dest_dim)).reshape(1, batch * d_cnt, self.hidden_size) + dest_emb, _ = self.att(dest_input, src_input, src_input, key_padding_mask=mask) + + dest_emb = dest_emb + self.att_dropout(dest_emb) + dest_emb = self.att_norm(dest_emb) + return dest_emb.reshape(batch, d_cnt, self.hidden_size) + + +class SimpleTransformer(nn.Module): + """Graph attention network with multiple graph in the CIM scenario. + + This module aggregates information in the port-to-port graph, port-to-vessel graph and vessel-to-port graph. The + aggregation in the two graph are done separatedly and then the port features are concatenated as the final result. + + Args: + p_dim (int): The feature dimension of the ports. + v_dim (int): The feature dimension of the vessels. + edge_dim (dict): The key is the edge name and the value is the corresponding feature dimension. + output_size (int): The hidden size in graph attention. + layer_num (int): The number of graph attention layers in each graph. + """ + + def __init__(self, p_dim, v_dim, edge_dim: dict, output_size, layer_num=2): + super().__init__() + self.hidden_size = output_size + self.layer_num = layer_num + + pl, vl, ppl = [], [], [] + for i in range(layer_num): + if i == 0: + pl.append(SimpleGATLayer(v_dim, p_dim, edge_dim["v"], self.hidden_size, nhead=4)) + vl.append(SimpleGATLayer(p_dim, v_dim, edge_dim["v"], self.hidden_size, nhead=4)) + # p2p links. + ppl.append( + SimpleGATLayer( + p_dim, p_dim, edge_dim["p"], self.hidden_size, nhead=4, position_encoding=False) + ) + else: + pl.append(SimpleGATLayer(self.hidden_size, self.hidden_size, 0, self.hidden_size, nhead=4)) + if i != layer_num - 1: + # p2v conv is not necessary at the last layer, for we only use port features. + vl.append(SimpleGATLayer(self.hidden_size, self.hidden_size, 0, self.hidden_size, nhead=4)) + ppl.append(SimpleGATLayer( + self.hidden_size, self.hidden_size, 0, self.hidden_size, nhead=4, position_encoding=False)) + self.p_layers = nn.ModuleList(pl) + self.v_layers = nn.ModuleList(vl) + self.pp_layers = nn.ModuleList(ppl) + + def forward(self, p, pe, v, ve, ppe): + """Do the multi-channel graph attention. + + Args: + p (Tensor): The port feature. + pe (Tensor): The vessel-port edge feature. + v (Tensor): The vessel feature. + ve (Tensor): The port-vessel edge feature. + ppe (Tensor): The port-port edge feature. + """ + # p.shape: (batch*p_cnt, p_dim) + pp = p + pre_p, pre_v, pre_pp = p, v, pp + for i in range(self.layer_num): + # Only feed edge info in the first layer. + p = self.p_layers[i](pre_v, pre_p, adj=pe["adj"], edges=pe["edge"] if i == 0 else None, mask=pe["mask"]) + if i != self.layer_num - 1: + v = self.v_layers[i]( + pre_p, pre_v, adj=ve["adj"], edges=ve["edge"] if i == 0 else None, mask=ve["mask"]) + pp = self.pp_layers[i]( + pre_pp, pre_pp, adj=ppe["adj"], edges=ppe["edge"] if i == 0 else None, mask=ppe["mask"]) + pre_p, pre_v, pre_pp = p, v, pp + p = torch.cat((p, pp), axis=2) + return p, v + + +class GeLU(nn.Module): + """Simple gelu wrapper as a independent module.""" + def __init__(self): + super().__init__() + + def forward(self, input): + return F.gelu(input) + + +class Header(nn.Module): + def __init__(self, input_size, hidden_size, output_size, net_type="res"): + super().__init__() + self.net_type = net_type + if net_type == "res": + self.fc_0 = nn.Linear(input_size, hidden_size) + self.act_0 = GeLU() + # self.do_0 = Dropout(dropout) + self.fc_1 = nn.Linear(hidden_size, input_size) + self.act_1 = GeLU() + self.fc_2 = nn.Linear(input_size, output_size) + elif net_type == "2layer": + self.fc_0 = nn.Linear(input_size, hidden_size) + self.act_0 = GeLU() + # self.do_0 = Dropout(dropout) + self.fc_1 = nn.Linear(hidden_size, hidden_size // 2) + self.act_1 = GeLU() + self.fc_2 = nn.Linear(hidden_size // 2, output_size) + elif net_type == "1layer": + self.fc_0 = nn.Linear(input_size, hidden_size) + self.act_0 = GeLU() + self.fc_1 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + if self.net_type == "res": + x1 = self.act_0(self.fc_0(x)) + x1 = self.act_1(self.fc_1(x1) + x) + return self.fc_2(x1) + elif self.net_type == "2layer": + x = self.act_0(self.fc_0(x)) + x = self.act_1(self.fc_1(x)) + x = self.fc_1(x) + return x + else: + x = self.fc_1(self.act_0(self.fc_0(x))) + return x + + +class SharedAC(nn.Module): + """The actor-critic module shared with multiple agents. + + This module maps the input graph of the observation to the policy and value space. It first extracts the temporal + information separately for each node with a small transformer block and then extracts the spatial information with + a multi-graph/channel graph attention. Finally, the extracted feature embedding is fed to a actor header as well + as a critic layer, which are the two MLPs with residual connections. + """ + + def __init__( + self, input_dim_p, edge_dim_p, input_dim_v, edge_dim_v, tick_buffer, action_dim, a=True, c=True, + scale=4, ac_head="res"): + super().__init__() + assert(a or c) + self.a, self.c = a, c + self.input_dim_v = input_dim_v + self.input_dim_p = input_dim_p + self.tick_buffer = tick_buffer + + self.pre_dim_v, self.pre_dim_p = 8 * scale, 16 * scale + self.p_pre_layer = nn.Sequential( + nn.Linear(input_dim_p, self.pre_dim_p), GeLU(), PositionalEncoder( + d_model=self.pre_dim_p, max_seq_len=tick_buffer)) + self.v_pre_layer = nn.Sequential( + nn.Linear(input_dim_v, self.pre_dim_v), GeLU(), PositionalEncoder( + d_model=self.pre_dim_v, max_seq_len=tick_buffer)) + p_encoder_layer = TransformerEncoderLayer( + d_model=self.pre_dim_p, nhead=4, activation="gelu", dim_feedforward=self.pre_dim_p * 4) + v_encoder_layer = TransformerEncoderLayer( + d_model=self.pre_dim_v, nhead=2, activation="gelu", dim_feedforward=self.pre_dim_v * 4) + + # Alternative initialization: define the normalization. + # self.trans_layer_p = TransformerEncoder(p_encoder_layer, num_layers=3, norm=Norm(self.pre_dim_p)) + # self.trans_layer_v = TransformerEncoder(v_encoder_layer, num_layers=3, norm=Norm(self.pre_dim_v)) + self.trans_layer_p = TransformerEncoder(p_encoder_layer, num_layers=3) + self.trans_layer_v = TransformerEncoder(v_encoder_layer, num_layers=3) + + self.gnn_output_size = 32 * scale + self.trans_gat = SimpleTransformer( + p_dim=self.pre_dim_p, + v_dim=self.pre_dim_v, + output_size=self.gnn_output_size // 2, + edge_dim={"p": edge_dim_p, "v": edge_dim_v}, + layer_num=2 + ) + + if a: + self.policy_hidden_size = 16 * scale + self.a_input = 3 * self.gnn_output_size // 2 + self.actor = nn.Sequential( + Header(self.a_input, self.policy_hidden_size, action_dim, ac_head), nn.Softmax(dim=-1)) + if c: + self.value_hidden_size = 16 * scale + self.c_input = self.gnn_output_size + self.critic = Header(self.c_input, self.value_hidden_size, 1, ac_head) + + def forward(self, state, a=False, p_idx=None, v_idx=None, c=False): + assert((a and p_idx is not None and v_idx is not None) or c) + feature_p, feature_v = state["p"], state["v"] + + tb, bsize, p_cnt, _ = feature_p.shape + v_cnt = feature_v.shape[2] + assert(tb == self.tick_buffer) + + # Before: feature_p.shape: (tick_buffer, batch_size, p_cnt, p_dim) + # After: feature_p.shape: (tick_buffer, batch_size*p_cnt, p_dim) + feature_p = self.p_pre_layer(feature_p.reshape(feature_p.shape[0], -1, feature_p.shape[-1])) + # state["mask"]: (batch_size, tick_buffer) + # mask_p: (batch_size, p_cnt, tick_buffer) + mask_p = state["mask"].repeat(1, p_cnt).reshape(-1, self.tick_buffer) + feature_p = self.trans_layer_p(feature_p, src_key_padding_mask=mask_p) + + feature_v = self.v_pre_layer(feature_v.reshape(feature_v.shape[0], -1, feature_v.shape[-1])) + mask_v = state["mask"].repeat(1, v_cnt).reshape(-1, self.tick_buffer) + feature_v = self.trans_layer_v(feature_v, src_key_padding_mask=mask_v) + + feature_p = feature_p[0].reshape(bsize, p_cnt, self.pre_dim_p) + feature_v = feature_v[0].reshape(bsize, v_cnt, self.pre_dim_v) + + emb_p, emb_v = self.trans_gat(feature_p, state["pe"], feature_v, state["ve"], state["ppe"]) + + a_rtn, c_rtn = None, None + if a and self.a: + ap = emb_p.reshape(bsize, p_cnt, self.gnn_output_size) + ap = ap[:, p_idx, :] + av = emb_v.reshape(bsize, v_cnt, self.gnn_output_size // 2) + av = av[:, v_idx, :] + emb_a = torch.cat((ap, av), axis=1) + a_rtn = self.actor(emb_a) + if c and self.c: + c_rtn = self.critic(emb_p).reshape(bsize, p_cnt) + return a_rtn, c_rtn diff --git a/examples/cim/gnn/state_shaper.py b/examples/cim/gnn/state_shaper.py new file mode 100644 index 000000000..cf360a178 --- /dev/null +++ b/examples/cim/gnn/state_shaper.py @@ -0,0 +1,234 @@ +import numpy as np + +from maro.rl.shaping.state_shaper import StateShaper +from examples.cim.gnn.utils import compute_v2p_degree_matrix + + +class GNNStateShaper(StateShaper): + """State shaper to extract graph information. + + Args: + port_code_list (list): The list of the port codes in the CIM topology. + vessel_code_list (list): The list of the vessel code in the CIM topology. + max_tick (int): The duration of the simulation. + feature_config (dict): The dottable dict that stores the configuration of the observation feature. + max_value (int): The norm scale. All the feature are simply divided by this number. + tick_buffer (int): The value n in n-step TD. + only_demo (bool): Define if the shaper instance is used only for shape demonstration(True) or runtime + shaping(False). + """ + + def __init__( + self, port_code_list, vessel_code_list, max_tick, feature_config, max_value=100000, tick_buffer=20, + only_demo=False): + # Collect and encode all ports. + self.port_code_list = list(port_code_list) + self.port_cnt = len(self.port_code_list) + self.port_code_inv_dict = {code: i for i, code in enumerate(self.port_code_list)} + + # Collect and encode all vessels. + self.vessel_code_list = list(vessel_code_list) + self.vessel_cnt = len(self.vessel_code_list) + self.vessel_code_inv_dict = {code: i for i, code in enumerate(self.vessel_code_list)} + + # Collect and encode ports and vessels together. + self.node_code_inv_dict_p = {i: i for i in self.port_code_list} + self.node_code_inv_dict_v = {i: i + self.port_cnt for i in self.vessel_code_list} + self.node_cnt = self.port_cnt + self.vessel_cnt + + one_hot_coding = np.identity(self.node_cnt) + self.port_one_hot_coding = np.expand_dims(one_hot_coding[:self.port_cnt], axis=0) + self.vessel_one_hot_coding = np.expand_dims(one_hot_coding[self.port_cnt:], axis=0) + self.last_tick = -1 + + self.port_features = [ + "empty", "full", "capacity", "on_shipper", "on_consignee", "booking", "acc_booking", "shortage", + "acc_shortage", "fulfillment", "acc_fulfillment"] + self.vessel_features = ["empty", "full", "capacity", "remaining_space"] + + self._max_tick = max_tick + self._tick_buffer = tick_buffer + # To identify one vessel would never arrive at the port. + self.max_arrival_time = 99999999 + + self.vedge_dim = 2 + self.pedge_dim = 1 + + self._only_demo = only_demo + self._feature_config = feature_config + self._normalize = True + self._norm_scale = 2.0 / max_value + if not only_demo: + self._state_dict = { + # Last "tick" is used for embedding, all zero and never be modified. + "v": np.zeros((self._max_tick + 1, self.vessel_cnt, self.get_input_dim("v"))), + "p": np.zeros((self._max_tick + 1, self.port_cnt, self.get_input_dim("p"))), + "vo": np.zeros((self._max_tick + 1, self.vessel_cnt, self.port_cnt), dtype=np.int), + "po": np.zeros((self._max_tick + 1, self.port_cnt, self.vessel_cnt), dtype=np.int), + "vedge": np.zeros((self._max_tick + 1, self.vessel_cnt, self.port_cnt, self.get_input_dim("vedge"))), + "pedge": np.zeros((self._max_tick + 1, self.port_cnt, self.vessel_cnt, self.get_input_dim("vedge"))), + "ppedge": np.zeros((self._max_tick + 1, self.port_cnt, self.port_cnt, self.get_input_dim("pedge"))), + } + + # Fixed order: in the order of degree. + + def compute_static_graph_structure(self, env): + v2p_adj_matrix = compute_v2p_degree_matrix(env) + p2p_adj_matrix = np.dot(v2p_adj_matrix.T, v2p_adj_matrix) + p2p_adj_matrix[p2p_adj_matrix == 0] = self.max_arrival_time + np.fill_diagonal(p2p_adj_matrix, self.max_arrival_time) + self._p2p_embedding = self.sort(p2p_adj_matrix) + + v2p_adj_matrix = -v2p_adj_matrix + v2p_adj_matrix[v2p_adj_matrix == 0] = self.max_arrival_time + self._fixed_v_order = self.sort(v2p_adj_matrix) + self._fixed_p_order = self.sort(v2p_adj_matrix.T) + + @property + def p2p_static_graph(self): + return self._p2p_embedding + + def sort(self, arrival_time, attr=None): + """ + Given the arrival time matrix, this function sort the matrix and return the index matrix in the order of + arrival time + """ + n, m = arrival_time.shape + if self._feature_config.attention_order == "ramdom": + arrival_time = arrival_time + np.random.randint(self._max_tick, size=arrival_time.shape) + at_index = np.argsort(arrival_time, axis=1) + if attr is not None: + idx_tmp = np.repeat(at_index, attr.shape[-1]).reshape(*at_index.shape, attr.shape[-1]) + attr = np.take_along_axis(attr, idx_tmp, axis=1) + mask = np.sort(arrival_time, axis=1) >= self.max_arrival_time + at_index += 1 + at_index[mask] = 0 + if attr is None: + return at_index + else: + return at_index, attr + + def end_ep_callback(self, snapshot_list): + if self._only_demo: + return + tick_range = np.arange(start=self.last_tick, stop=self._max_tick) + self._sync_raw_features(snapshot_list, list(tick_range)) + self.last_tick = -1 + + def _sync_raw_features(self, snapshot_list, tick_range, static_code=None, dynamic_code=None): + """This function update the state_dict from snapshot_list in the given tick_range.""" + if len(tick_range) == 0: + # This occurs when two actions happen at the same tick. + return + + # One dim features. + port_naive_feature = snapshot_list["ports"][tick_range: self.port_code_list: self.port_features] \ + .reshape(len(tick_range), self.port_cnt, -1) + # Number of laden from source to destination. + full_on_port = snapshot_list["matrices"][tick_range::"full_on_ports"].reshape( + len(tick_range), self.port_cnt, self.port_cnt) + # Normalize features to a small range. + port_state_mat = self.normalize(port_naive_feature) + + if self._feature_config.onehot_identity: + # Add onehot vector to identify port and vessel. + port_onehot = np.repeat(self.port_one_hot_coding, len(tick_range), axis=0) + if static_code is not None and dynamic_code is not None: + # Identify the decision vessel at the decision port. + port_onehot[-1, self.port_code_inv_dict[static_code], self.node_code_inv_dict_v[dynamic_code]] = -1 + port_state_mat = np.concatenate([port_state_mat, port_onehot], axis=2) + self._state_dict["p"][tick_range] = port_state_mat + + vessel_naive_feature = snapshot_list["vessels"][tick_range:self.vessel_code_list: self.vessel_features] \ + .reshape(len(tick_range), self.vessel_cnt, -1) + full_on_vessel = snapshot_list["matrices"][tick_range::"full_on_vessels"].reshape( + len(tick_range), self.vessel_cnt, self.port_cnt) + + vessel_state_mat = self.normalize(vessel_naive_feature) + if self._feature_config.onehot_identity: + vessel_state_mat = np.concatenate( + [vessel_state_mat, np.repeat(self.vessel_one_hot_coding, len(tick_range), axis=0)], axis=2) + self._state_dict["v"][tick_range] = vessel_state_mat + + # last_arrival_time.shape: vessel_cnt * port_cnt + # -1 means one vessel never stops at the port + vessel_arrival_time = snapshot_list["matrices"][tick_range[-1]:: "vessel_plans"].reshape( + self.vessel_cnt, self.port_cnt) + # Use infinity time to identify vessels never arrive at the port. + last_arrival_time = vessel_arrival_time + 1 + last_arrival_time[last_arrival_time == 0] = self.max_arrival_time + if static_code is not None and dynamic_code is not None: + # To differentiate vessel acting on the port and other vessels that have taken or wait to take actions. + last_arrival_time[self.vessel_code_inv_dict[dynamic_code], self.port_code_inv_dict[static_code]] = 0 + + # Here, we assume that the order of arriving time between two action/event is all the same. + vedge_raw = self.normalize(np.stack((full_on_vessel[-1], last_arrival_time), axis=-1)) + vo, vedge = self.sort(last_arrival_time, attr=vedge_raw) + po, pedge = self.sort(last_arrival_time.T, attr=vedge_raw.transpose((1, 0, 2))) + self._state_dict["vo"][tick_range] = np.expand_dims(vo, axis=0) + self._state_dict["vedge"][tick_range] = np.expand_dims(vedge, axis=0) + self._state_dict["po"][tick_range] = np.expand_dims(po, axis=0) + self._state_dict["pedge"][tick_range] = np.expand_dims(pedge, axis=0) + self._state_dict["ppedge"][tick_range] = self.normalize(full_on_port[-1]).reshape(1, *full_on_port[-1].shape, 1) + + def __call__(self, action_info=None, snapshot_list=None, tick=None): + if self._only_demo: + return + assert((action_info is not None and snapshot_list is not None) or tick is not None) + + if action_info is not None and snapshot_list is not None: + # Update the state dict. + static_code = action_info.port_idx + dynamic_code = action_info.vessel_idx + if self.last_tick == action_info.tick: + tick_range = [action_info.tick] + else: + tick_range = list(range(self.last_tick + 1, action_info.tick + 1, 1)) + + self.last_tick = action_info.tick + self._sync_raw_features(snapshot_list, tick_range, static_code, dynamic_code) + tick = action_info.tick + + # State_tick_range is inverse order. + state_tick_range = np.arange(tick, max(-1, tick - self._tick_buffer), -1) + v = np.zeros((self._tick_buffer, self.vessel_cnt, self.get_input_dim("v"))) + v[:len(state_tick_range)] = self._state_dict["v"][state_tick_range] + p = np.zeros((self._tick_buffer, self.port_cnt, self.get_input_dim("p"))) + p[:len(state_tick_range)] = self._state_dict["p"][state_tick_range] + + # True means padding. + mask = np.ones(self._tick_buffer, dtype=np.bool) + mask[:len(state_tick_range)] = False + ret = { + "tick": state_tick_range, + "v": v, + "p": p, + "vo": self._state_dict["vo"][tick], + "po": self._state_dict["po"][tick], + "vedge": self._state_dict["vedge"][tick], + "pedge": self._state_dict["pedge"][tick], + "ppedge": self._state_dict["ppedge"][tick], + "mask": mask, + "len": len(state_tick_range), + } + + return ret + + def normalize(self, feature): + if not self._normalize: + return feature + return feature * self._norm_scale + + def get_input_dim(self, agent_code): + if agent_code in self.port_code_inv_dict or agent_code == "p": + return len(self.port_features) + (self.node_cnt if self._feature_config.onehot_identity else 0) + elif agent_code in self.vessel_code_inv_dict or agent_code == "v": + return len(self.vessel_features) + (self.node_cnt if self._feature_config.onehot_identity else 0) + elif agent_code == "vedge": + # v-p edge: (arrival_time, laden to destination) + return 2 + elif agent_code == "pedge": + # p-p edge: (laden to destination, ) + return 1 + else: + raise ValueError("agent not exist!") diff --git a/examples/cim/gnn/utils.py b/examples/cim/gnn/utils.py new file mode 100644 index 000000000..4c23c10aa --- /dev/null +++ b/examples/cim/gnn/utils.py @@ -0,0 +1,266 @@ +import ast +import io +import numpy as np +import random +import os +import shutil +import sys +import yaml +from collections import defaultdict, OrderedDict + +import torch + +from maro.simulator import Env +from maro.simulator.scenarios.cim.common import Action +from maro.utils import convert_dottable, clone + + +def compute_v2p_degree_matrix(env): + """This function compute the adjacent matrix.""" + topo_config = env.configs + static_dict = env.summary["node_mapping"]["ports"] + dynamic_dict = env.summary["node_mapping"]["vessels"] + adj_matrix = np.zeros((len(dynamic_dict), len(static_dict)), dtype=np.int) + for v, vinfo in topo_config["vessels"].items(): + route_name = vinfo["route"]["route_name"] + route = topo_config["routes"][route_name] + vid = dynamic_dict[v] + for p in route: + adj_matrix[vid][static_dict[p["port_name"]]] += 1 + + return adj_matrix + + +def from_numpy(device, *np_values): + return [torch.from_numpy(v).to(device) for v in np_values] + + +def gnn_union(p, po, pedge, v, vo, vedge, p2p, ppedge, seq_mask, device): + """Union multiple graph in CIM. + + Args: + v: Numpy array of shape (seq_len, batch, v_cnt, v_dim). + vo: Numpy array of shape (batch, v_cnt, p_cnt). + vedge: Numpy array of shape (batch, v_cnt, p_cnt, e_dim). + Returns: + result (dict): The dictionary that describes the graph. + """ + seq_len, batch, v_cnt, v_dim = v.shape + _, _, p_cnt, p_dim = p.shape + + p, po, pedge, v, vo, vedge, p2p, ppedge, seq_mask = from_numpy( + device, p, po, pedge, v, vo, vedge, p2p, ppedge, seq_mask) + + batch_range = torch.arange(batch, dtype=torch.long).to(device) + # vadj.shape: (batch*v_cnt, p_cnt*) + vadj, vedge = flatten_embedding(vo, batch_range, vedge) + # vmask.shape: (batch*v_cnt, p_cnt*) + vmask = vadj == 0 + # vadj.shape: (p_cnt*, batch*v_cnt) + vadj = vadj.transpose(0, 1) + # vedge.shape: (p_cnt*, batch*v_cnt, e_dim) + vedge = vedge.transpose(0, 1) + + padj, pedge = flatten_embedding(po, batch_range, pedge) + pmask = padj == 0 + padj = padj.transpose(0, 1) + pedge = pedge.transpose(0, 1) + + p2p_adj = p2p.repeat(batch, 1, 1) + # p2p_adj.shape: (batch*p_cnt, p_cnt*) + p2p_adj, ppedge = flatten_embedding(p2p_adj, batch_range, ppedge) + # p2p_mask.shape: (batch*p_cnt, p_cnt*) + p2p_mask = p2p_adj == 0 + # p2p_adj.shape: (p_cnt*, batch*p_cnt) + p2p_adj = p2p_adj.transpose(0, 1) + ppedge = ppedge.transpose(0, 1) + + return { + "v": v, + "p": p, + "pe": { + "edge": pedge, + "adj": padj, + "mask": pmask, + }, + "ve": { + "edge": vedge, + "adj": vadj, + "mask": vmask, + }, + "ppe": { + "edge": ppedge, + "adj": p2p_adj, + "mask": p2p_mask, + }, + "mask": seq_mask, + } + + +def flatten_embedding(embedding, batch_range, edge=None): + if len(embedding.shape) == 3: + batch, x_cnt, y_cnt = embedding.shape + addon = (batch_range * y_cnt).view(batch, 1, 1) + else: + seq_len, batch, x_cnt, y_cnt = embedding.shape + addon = (batch_range * y_cnt).view(seq_len, batch, 1, 1) + + embedding_mask = embedding == 0 + embedding += addon + embedding[embedding_mask] = 0 + ret = embedding.reshape(-1, embedding.shape[-1]) + col_mask = ret.sum(dim=0) != 0 + ret = ret[:, col_mask] + if edge is None: + return ret + else: + edge = edge.reshape(-1, *edge.shape[2:])[:, col_mask, :] + return ret, edge + + +def log2json(file_path): + """load the log file as a json list.""" + with open(file_path, "r") as fp: + lines = fp.read().splitlines() + json_list = "[" + ",".join(lines) + "]" + return ast.literal_eval(json_list) + + +def decision_cnt_analysis(env, pv=False, buffer_size=8): + if not pv: + decision_cnt = [buffer_size] * len(env.node_name_mapping["static"]) + r, pa, is_done = env.step(None) + while not is_done: + decision_cnt[pa.port_idx] += 1 + action = Action(pa.vessel_idx, pa.port_idx, 0) + r, pa, is_done = env.step(action) + else: + decision_cnt = OrderedDict() + r, pa, is_done = env.step(None) + while not is_done: + if (pa.port_idx, pa.vessel_idx) not in decision_cnt: + decision_cnt[pa.port_idx, pa.vessel_idx] = buffer_size + else: + decision_cnt[pa.port_idx, pa.vessel_idx] += 1 + action = Action(pa.vessel_idx, pa.port_idx, 0) + r, pa, is_done = env.step(action) + env.reset() + return decision_cnt + + +def random_shortage(env, tick, action_dim=21): + _, pa, is_done = env.step(None) + node_cnt = len(env.summary["node_mapping"]["ports"]) + while not is_done: + """ + load, discharge = pa.action_scope.load, pa.action_scope.discharge + action_idx = np.random.randint(action_dim) - zero_idx + if action_idx < 0: + actual_action = int(1.0*action_idx/zero_idx*load) + else: + actual_action = int(1.0*action_idx/zero_idx*discharge) + """ + action = Action(pa.vessel_idx, pa.port_idx, 0) + r, pa, is_done = env.step(action) + + shs = env.snapshot_list["ports"][tick - 1:list(range(node_cnt)):"acc_shortage"] + fus = env.snapshot_list["ports"][tick - 1:list(range(node_cnt)):"acc_fulfillment"] + env.reset() + return fus - shs, np.sum(shs + fus) + + +def return_scaler(env, tick, gamma, action_dim=21): + R, tot_amount = random_shortage(env, tick, action_dim) + Rs_mean = np.mean(R) / tick / (1 - gamma) + return abs(1.0 / Rs_mean), tot_amount + + +def load_config(config_pth): + with io.open(config_pth, "r") as in_file: + raw_config = yaml.safe_load(in_file) + config = convert_dottable(raw_config) + + if config.env.seed < 0: + config.env.seed = random.randint(0, 99999) + + regularize_config(config) + return config + + +def save_config(config, config_pth): + with open(config_pth, "w") as fp: + config = dottable2dict(config) + config["env"]["exp_per_ep"] = [f"{k[0]}, {k[1]}, {d}" for k, d in config["env"]["exp_per_ep"].items()] + yaml.safe_dump(config, fp) + + +def dottable2dict(config): + if isinstance(config, float): + return str(config) + if not isinstance(config, dict): + return clone(config) + rt = {} + for k, v in config.items(): + rt[k] = dottable2dict(v) + return rt + + +def save_code(folder, save_pth): + save_path = os.path.join(save_pth, "code") + code_pth = os.path.join(os.getcwd(), folder) + shutil.copytree(code_pth, save_path) + + +def fix_seed(env, seed): + env.set_seed(seed) + np.random.seed(seed) + random.seed(seed) + + +def zero_play(**args): + env = Env(**args) + _, pa, is_done = env.step(None) + while not is_done: + action = Action(pa.vessel_idx, pa.port_idx, 0) + r, pa, is_done = env.step(action) + return env.snapshot_list + + +def regularize_config(config): + def parse_value(v): + try: + return int(v) + except ValueError: + try: + return float(v) + except ValueError: + if v == "false" or v == "False": + return False + elif v == "true" or v == "True": + return True + else: + return v + + def set_attr(config, attrs, value): + if len(attrs) == 1: + config[attrs[0]] = value + else: + set_attr(config[attrs[0]], attrs[1:], value) + + all_args = sys.argv[1:] + for i in range(len(all_args) // 2): + name = all_args[i * 2] + attrs = name[2:].split(".") + value = parse_value(all_args[i * 2 + 1]) + set_attr(config, attrs, value) + + +def analysis_speed(env): + speed_dict = defaultdict(int) + eq_speed = 0 + for ves in env.configs["vessels"].values(): + speed_dict[ves["sailing"]["speed"]] += 1 + for sp, cnt in speed_dict.items(): + eq_speed += 1.0 * cnt / sp + eq_speed = 1.0 / eq_speed + return speed_dict, eq_speed