diff --git a/maro/rl/policy/discrete_rl_policy.py b/maro/rl/policy/discrete_rl_policy.py index c5d67b7e4..50e6d8943 100644 --- a/maro/rl/policy/discrete_rl_policy.py +++ b/maro/rl/policy/discrete_rl_policy.py @@ -244,7 +244,7 @@ def get_state(self) -> dict: } def set_state(self, policy_state: dict) -> None: - self._q_net.set_state(policy_state) + self._q_net.set_state(policy_state["net"]) self._warmup = policy_state["policy"]["warmup"] self._call_count = policy_state["policy"]["call_count"]