-
Notifications
You must be signed in to change notification settings - Fork 157
initalize gnn network for CIM problem #60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
3b59231
initalize gnn netowrk for ECR problem
WenleiShi 3370f12
fix import path error and remove useless function
WenleiShi 60534e4
add the annotation for gnn and state shaper
WenleiShi a39ad4f
Merge branch 'v0.1' into v0.1_feature_ecr_gnn
WenleiShi 882ef26
rename to cim
WenleiShi 79ff455
Merge branch 'v0.1' into v0.1_feature_ecr_gnn
WenleiShi 96f058f
modify topology 22p
WenleiShi f79cebd
Merge branch 'v0.1' into v0.1_feature_ecr_gnn
WenleiShi d8f802a
Merge branch 'v0.1' into v0.1_feature_ecr_gnn
WenleiShi 44754f9
polish coding style in cim.gnn
wesley-stone 1e78d0b
run pass after polishing
WenleiShi 686cd33
change the name of simplegat to simple_transformer
WenleiShi 7c3a9cd
fix typo and suggestions
WenleiShi 35f3b39
Merge branch 'v0.1' into v0.1_feature_ecr_gnn
wesley-stone b1a400a
fix format
WenleiShi 2f3f071
fix code bugs due to refactoring mistake
WenleiShi 522eeb5
refine the code format
WenleiShi 7da898d
refine again
WenleiShi f9c779f
refine the format
WenleiShi 0fd4f5b
refine again
WenleiShi da6fa77
refine again
WenleiShi 53d5584
Merge branch 'master' into v0.1_feature_ecr_gnn
WenleiShi 40c8047
fix syntax bug
WenleiShi 4d7d0e1
Merge branch 'v0.1' into v0.1_feature_ecr_gnn
wesley-stone c86d5d9
refine again
WenleiShi ba84802
remove useless comments
WenleiShi 255ed67
remove unnecessary blank line
WenleiShi c886194
rename pending_event to decision_event
WenleiShi 669da1a
remove some code faced to the future.
WenleiShi b207483
change topology
WenleiShi 51f9568
fix bugs in learner: no training logic executed
WenleiShi 3804af3
revert the topology changes
WenleiShi ec6d790
Merge branch 'v0.1' into v0.1_feature_ecr_gnn
WenleiShi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| from maro.rl import ActionShaper | ||
|
|
||
|
|
||
| class DiscreteActionShaper(ActionShaper): | ||
| """The shaping class to transform the action in [-1, 1] to actual repositioning function.""" | ||
| def __init__(self, action_dim): | ||
| super().__init__() | ||
| self._action_dim = action_dim | ||
| self._zero_action = self._action_dim // 2 | ||
|
|
||
| def __call__(self, decision_event, model_action): | ||
| """Shaping the action in [-1,1] range to the actual repositioning function. | ||
|
|
||
| This function maps integer model action within the range of [-A, A] to actual action. We define negative actual | ||
| action as discharge resource from vessel to port and positive action as upload from port to vessel, so the | ||
| upper bound and lower bound of actual action are the resource in dynamic and static node respectively. | ||
|
|
||
| Args: | ||
| decision_event (Event): The decision event from the environment. | ||
| model_action (int): Output action, range A means the half of the agent output dim. | ||
| """ | ||
| env_action = 0 | ||
| model_action -= self._zero_action | ||
|
|
||
| action_scope = decision_event.action_scope | ||
|
|
||
| if model_action < 0: | ||
| # Discharge resource from dynamic node. | ||
| env_action = round(int(model_action) * 1.0 / self._zero_action * action_scope.load) | ||
| elif model_action == 0: | ||
| env_action = 0 | ||
| else: | ||
| # Load resource to dynamic node. | ||
| env_action = round(int(model_action) * 1.0 / self._zero_action * action_scope.discharge) | ||
| env_action = int(env_action) | ||
|
|
||
| return env_action | ||
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,179 @@ | ||
| import os | ||
|
|
||
| import torch | ||
| from torch import nn | ||
| from torch.distributions import Categorical | ||
| from torch.nn.utils import clip_grad | ||
|
|
||
| from examples.cim.gnn.utils import gnn_union | ||
|
wesley-stone marked this conversation as resolved.
|
||
| from maro.rl import AbsAlgorithm | ||
|
|
||
|
|
||
| class ActorCritic(AbsAlgorithm): | ||
| """Actor-Critic algorithm in CIM problem. | ||
|
|
||
| The vanilla ac algorithm. | ||
|
|
||
| Args: | ||
| model (nn.Module): A actor-critic module outputing both the policy network and the value network. | ||
| device (torch.device): A PyTorch device instance where the module is computed on. | ||
| p2p_adj (numpy.array): The static port-to-port adjencency matrix. | ||
| td_steps (int): The value "n" in the n-step TD algorithm. | ||
| gamma (float): The time decay. | ||
| learning_rate (float): The learning rate for the module. | ||
| entropy_factor (float): The weight of the policy"s entropy to boost exploration. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, model: nn.Module, device: torch.device, p2p_adj=None, td_steps=100, gamma=0.97, learning_rate=0.0003, | ||
| entropy_factor=0.1): | ||
| self._gamma = gamma | ||
| self._td_steps = td_steps | ||
| self._value_discount = gamma**100 | ||
| self._entropy_factor = entropy_factor | ||
| self._device = device | ||
| self._tot_batchs = 0 | ||
| self._p2p_adj = p2p_adj | ||
| super().__init__( | ||
| model_dict={"a&c": model}, optimizer_opt={"a&c": (torch.optim.Adam, {"lr": learning_rate})}, | ||
| loss_func_dict={}, hyper_params=None) | ||
|
|
||
| def choose_action(self, state: dict, p_idx: int, v_idx: int): | ||
| """Get action from the AC model. | ||
|
|
||
| Args: | ||
| state (dict): A dictionary containing the input to the module. For example: | ||
| { | ||
| "v": v, | ||
| "p": p, | ||
| "pe": { | ||
| "edge": pedge, | ||
| "adj": padj, | ||
| "mask": pmask, | ||
| }, | ||
| "ve": { | ||
| "edge": vedge, | ||
| "adj": vadj, | ||
| "mask": vmask, | ||
| }, | ||
| "ppe": { | ||
| "edge": ppedge, | ||
| "adj": p2p_adj, | ||
| "mask": p2p_mask, | ||
| }, | ||
| "mask": seq_mask, | ||
| } | ||
| p_idx (int): The identity of the port doing the action. | ||
| v_idx (int): The identity of the vessel doing the action. | ||
|
|
||
| Returns: | ||
| model_action (numpy.int64): The action returned from the module. | ||
| """ | ||
| with torch.no_grad(): | ||
| prob, _ = self._model_dict["a&c"](state, a=True, p_idx=p_idx, v_idx=v_idx) | ||
| distribution = Categorical(prob) | ||
| model_action = distribution.sample().cpu().numpy() | ||
| return model_action | ||
|
|
||
| def train(self, batch, p_idx, v_idx): | ||
| """Model training. | ||
|
|
||
| Args: | ||
| batch (dict): The dictionary of a batch of experience. For example: | ||
| { | ||
| "s": the dictionary of state, | ||
| "a": model actions in numpy array, | ||
| "R": the n-step accumulated reward, | ||
| "s"": the dictionary of the next state, | ||
| } | ||
| p_idx (int): The identity of the port doing the action. | ||
| v_idx (int): The identity of the vessel doing the action. | ||
|
|
||
| Returns: | ||
| a_loss (float): action loss. | ||
| c_loss (float): critic loss. | ||
| e_loss (float): entropy loss. | ||
| tot_norm (float): the L2 norm of the gradient. | ||
|
|
||
| """ | ||
| self._tot_batchs += 1 | ||
| item_a_loss, item_c_loss, item_e_loss = 0, 0, 0 | ||
| obs_batch = batch["s"] | ||
| action_batch = batch["a"] | ||
| return_batch = batch["R"] | ||
| next_obs_batch = batch["s_"] | ||
|
|
||
| obs_batch = gnn_union( | ||
| obs_batch["p"], obs_batch["po"], obs_batch["pedge"], obs_batch["v"], obs_batch["vo"], obs_batch["vedge"], | ||
| self._p2p_adj, obs_batch["ppedge"], obs_batch["mask"], self._device) | ||
| action_batch = torch.from_numpy(action_batch).long().to(self._device) | ||
| return_batch = torch.from_numpy(return_batch).float().to(self._device) | ||
| next_obs_batch = gnn_union( | ||
| next_obs_batch["p"], next_obs_batch["po"], next_obs_batch["pedge"], next_obs_batch["v"], | ||
| next_obs_batch["vo"], next_obs_batch["vedge"], self._p2p_adj, next_obs_batch["ppedge"], | ||
| next_obs_batch["mask"], self._device) | ||
|
|
||
| # Train actor network. | ||
| self._optimizer["a&c"].zero_grad() | ||
|
|
||
| # Every port has a value. | ||
| # values.shape: (batch, p_cnt) | ||
| probs, values = self._model_dict["a&c"](obs_batch, a=True, p_idx=p_idx, v_idx=v_idx, c=True) | ||
| distribution = Categorical(probs) | ||
| log_prob = distribution.log_prob(action_batch) | ||
| entropy_loss = distribution.entropy() | ||
|
|
||
| _, values_ = self._model_dict["a&c"](next_obs_batch, c=True) | ||
| advantage = return_batch + self._value_discount * values_.detach() - values | ||
|
|
||
| if self._entropy_factor != 0: | ||
| # actor_loss = actor_loss* torch.log(entropy_loss + np.e) | ||
| advantage[:, p_idx] += self._entropy_factor * entropy_loss.detach() | ||
|
|
||
| actor_loss = - (log_prob * torch.sum(advantage, axis=-1).detach()).mean() | ||
|
|
||
| item_a_loss = actor_loss.item() | ||
| item_e_loss = entropy_loss.mean().item() | ||
|
|
||
| # Train critic network. | ||
| critic_loss = torch.sum(advantage.pow(2), axis=1).mean() | ||
| item_c_loss = critic_loss.item() | ||
| # torch.nn.utils.clip_grad_norm_(self._critic_model.parameters(),0.5) | ||
|
wesley-stone marked this conversation as resolved.
|
||
| tot_loss = 0.1 * actor_loss + critic_loss | ||
| tot_loss.backward() | ||
| tot_norm = clip_grad.clip_grad_norm_(self._model_dict["a&c"].parameters(), 1) | ||
| self._optimizer["a&c"].step() | ||
| return item_a_loss, item_c_loss, item_e_loss, float(tot_norm) | ||
|
|
||
| def set_weights(self, weights): | ||
| self._model_dict["a&c"].load_state_dict(weights) | ||
|
|
||
| def get_weights(self): | ||
| return self._model_dict["a&c"].state_dict() | ||
|
|
||
| def _get_save_idx(self, fp_str): | ||
| return int(fp_str.split(".")[0].split("_")[0]) | ||
|
|
||
| def save_model(self, pth, id): | ||
| if not os.path.exists(pth): | ||
| os.makedirs(pth) | ||
| pth = os.path.join(pth, f"{id}_ac.pkl") | ||
| torch.save(self._model_dict["a&c"].state_dict(), pth) | ||
|
|
||
| def _set_gnn_weights(self, weights): | ||
| for key in weights: | ||
| if key in self._model_dict["a&c"].state_dict().keys(): | ||
| self._model_dict["a&c"].state_dict()[key].copy_(weights[key]) | ||
|
|
||
| def load_model(self, folder_pth, idx=-1): | ||
| if idx == -1: | ||
| fps = os.listdir(folder_pth) | ||
| fps = [f for f in fps if "ac" in f] | ||
| fps.sort(key=self._get_save_idx) | ||
| ac_pth = fps[-1] | ||
| else: | ||
| ac_pth = f"{idx}_ac.pkl" | ||
| pth = os.path.join(folder_pth, ac_pth) | ||
| with open(pth, "rb") as fp: | ||
| weights = torch.load(fp, map_location=self._device) | ||
| self._set_gnn_weights(weights) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| from collections import defaultdict | ||
|
|
||
| import numpy as np | ||
|
|
||
| from examples.cim.gnn.numpy_store import Shuffler | ||
| from maro.rl import AbsAgent | ||
| from maro.utils import DummyLogger | ||
|
|
||
|
|
||
| class TrainableAgent(AbsAgent): | ||
| def __init__(self, name, algorithm, experience_pool, logger=DummyLogger()): | ||
| self._logger = logger | ||
| super().__init__(name, algorithm, experience_pool) | ||
|
|
||
| def train(self, training_config): | ||
| loss_dict = defaultdict(list) | ||
| for j in range(training_config.shuffle_time): | ||
| shuffler = Shuffler(self._experience_pool, batch_size=training_config.batch_size) | ||
| while shuffler.has_next(): | ||
| batch = shuffler.next() | ||
| actor_loss, critic_loss, entropy_loss, tot_loss = self._algorithm.train( | ||
| batch, self._name[0], self._name[1]) | ||
| loss_dict["actor"].append(actor_loss) | ||
| loss_dict["critic"].append(critic_loss) | ||
| loss_dict["entropy"].append(entropy_loss) | ||
| loss_dict["tot"].append(tot_loss) | ||
|
|
||
| a_loss = np.mean(loss_dict["actor"]) | ||
| c_loss = np.mean(loss_dict["critic"]) | ||
| e_loss = np.mean(loss_dict["entropy"]) | ||
| tot_loss = np.mean(loss_dict["tot"]) | ||
| self._logger.debug( | ||
| f"code: {str(self._name)} \t actor: {float(a_loss)} \t critic: {float(c_loss)} \t entropy: {float(e_loss)} \ | ||
| \t tot: {float(tot_loss)}") | ||
|
|
||
| self._experience_pool.clear() | ||
| return loss_dict | ||
|
|
||
| def choose_action(self, model_state): | ||
| return self._algorithm.choose_action(model_state, self._name[0], self._name[1]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| from copy import copy | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from examples.cim.gnn.agent import TrainableAgent | ||
| from examples.cim.gnn.actor_critic import ActorCritic | ||
| from examples.cim.gnn.simple_gnn import SharedAC | ||
| from examples.cim.gnn.numpy_store import NumpyStore | ||
| from examples.cim.gnn.state_shaper import GNNStateShaper | ||
| from maro.rl import AbsAgentManager, AgentMode | ||
| from maro.utils import DummyLogger | ||
|
|
||
|
|
||
| class SimpleAgentManger(AbsAgentManager): | ||
| def __init__( | ||
| self, name, agent_id_list, port_code_list, vessel_code_list, demo_env, state_shaper: GNNStateShaper, | ||
| logger=DummyLogger()): | ||
| super().__init__( | ||
| name, AgentMode.TRAIN, agent_id_list, state_shaper=state_shaper, action_shaper=None, | ||
| experience_shaper=None, explorer=None) | ||
| self.port_code_list = copy(port_code_list) | ||
| self.vessel_code_list = copy(vessel_code_list) | ||
| self.demo_env = demo_env | ||
| self._logger = logger | ||
|
|
||
| def assemble(self, config): | ||
| v_dim, vedge_dim = self._state_shaper.get_input_dim("v"), self._state_shaper.get_input_dim("vedge") | ||
| p_dim, pedge_dim = self._state_shaper.get_input_dim("p"), self._state_shaper.get_input_dim("pedge") | ||
|
|
||
| self.device = torch.device(config.training.device) | ||
| self._logger.info(config.training.device) | ||
| ac_model = SharedAC( | ||
| p_dim, pedge_dim, v_dim, vedge_dim, config.model.tick_buffer, config.model.action_dim).to(self.device) | ||
|
|
||
| value_dict = { | ||
| ("s", "v"): | ||
| ( | ||
| (config.model.tick_buffer, len(self.vessel_code_list), self._state_shaper.get_input_dim("v")), | ||
| np.float32, False), | ||
| ("s", "p"): | ||
| ( | ||
| (config.model.tick_buffer, len(self.port_code_list), self._state_shaper.get_input_dim("p")), | ||
| np.float32, False), | ||
| ("s", "vo"): ((len(self.vessel_code_list), len(self.port_code_list)), np.int64, True), | ||
| ("s", "po"): ((len(self.port_code_list), len(self.vessel_code_list)), np.int64, True), | ||
| ("s", "vedge"): | ||
| ( | ||
| (len(self.vessel_code_list), len(self.port_code_list), self._state_shaper.get_input_dim("vedge")), | ||
| np.float32, True), | ||
| ("s", "pedge"): | ||
| ( | ||
| (len(self.port_code_list), len(self.vessel_code_list), self._state_shaper.get_input_dim("vedge")), | ||
| np.float32, True), | ||
| ("s", "ppedge"): | ||
| ( | ||
| (len(self.port_code_list), len(self.port_code_list), self._state_shaper.get_input_dim("pedge")), | ||
| np.float32, True), | ||
| ("s", "mask"): ((config.model.tick_buffer, ), np.bool, True), | ||
|
|
||
| ("s_", "v"): | ||
| ( | ||
| (config.model.tick_buffer, len(self.vessel_code_list), self._state_shaper.get_input_dim("v")), | ||
| np.float32, False), | ||
| ("s_", "p"): | ||
| ( | ||
| (config.model.tick_buffer, len(self.port_code_list), self._state_shaper.get_input_dim("p")), | ||
| np.float32, False), | ||
| ("s_", "vo"): ((len(self.vessel_code_list), len(self.port_code_list)), np.int64, True), | ||
| ("s_", "po"): | ||
| ( | ||
| (len(self.port_code_list), len(self.vessel_code_list)), np.int64, True), | ||
| ("s_", "vedge"): | ||
| ( | ||
| (len(self.vessel_code_list), len(self.port_code_list), self._state_shaper.get_input_dim("vedge")), | ||
| np.float32, True), | ||
| ("s_", "pedge"): | ||
| ( | ||
| (len(self.port_code_list), len(self.vessel_code_list), self._state_shaper.get_input_dim("vedge")), | ||
| np.float32, True), | ||
| ("s_", "ppedge"): | ||
| ( | ||
| (len(self.port_code_list), len(self.port_code_list), self._state_shaper.get_input_dim("pedge")), | ||
| np.float32, True), | ||
| ("s_", "mask"): ((config.model.tick_buffer, ), np.bool, True), | ||
|
|
||
| # To identify one dimension variable. | ||
| ("R",): ((len(self.port_code_list), ), np.float32, True), | ||
| ("a",): (tuple(), np.int64, True), | ||
| } | ||
|
|
||
| self._algorithm = ActorCritic( | ||
| ac_model, self.device, td_steps=config.training.td_steps, p2p_adj=self._state_shaper.p2p_static_graph, | ||
| gamma=config.training.gamma, learning_rate=config.training.learning_rate) | ||
|
|
||
| for agent_id, cnt in config.env.exp_per_ep.items(): | ||
| experience_pool = NumpyStore(value_dict, config.training.parallel_cnt * config.training.train_freq * cnt) | ||
| self._agent_dict[agent_id] = TrainableAgent(agent_id, self._algorithm, experience_pool, self._logger) | ||
|
|
||
| def choose_action(self, agent_id, state): | ||
| return self._agent_dict[agent_id].choose_action(state) | ||
|
|
||
| def load_models_from_files(self, model_pth): | ||
| self._algorithm.load_model(model_pth) | ||
|
|
||
| def train(self, training_config): | ||
| for agent in self._agent_dict.values(): | ||
| agent.train(training_config) | ||
|
|
||
| def store_experiences(self, experiences): | ||
| for code, exp_list in experiences.items(): | ||
| self._agent_dict[code].store_experiences(exp_list) | ||
|
|
||
| def save_model(self, pth, id): | ||
| self._algorithm.save_model(pth, id) | ||
|
|
||
| def load_model(self, pth): | ||
| self._algorithm.load_model(pth) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.