diff --git a/ding/envs/env_manager/env_supervisor.py b/ding/envs/env_manager/env_supervisor.py index b3e1c17fcf..a12a20ed0d 100644 --- a/ding/envs/env_manager/env_supervisor.py +++ b/ding/envs/env_manager/env_supervisor.py @@ -61,6 +61,7 @@ def __init__( episode_num: int = float("inf"), shared_memory: bool = True, copy_on_get: bool = True, + return_original_data: bool = False, **kwargs ) -> None: """ @@ -78,6 +79,9 @@ def __init__( - retry_waiting_time (:obj:`Optional[float]`): Wait time on each retry. - shared_memory (:obj:`bool`): Use shared memory in multiprocessing. - copy_on_get (:obj:`bool`): Use copy on get in multiprocessing. + - return_original_data (:obj:`bool`): Return original observation, + so that the attribute self._ready_obs is not a tnp.array but only the original observation, + and the property self.ready_obs is a dict in which the key is the env_id. """ if kwargs: logging.warning("Unknown parameters on env supervisor: {}".format(kwargs)) @@ -122,6 +126,7 @@ def __init__( self._retry_waiting_time = retry_waiting_time self._env_replay_path = None self._episode_num = episode_num + self._return_original_data = return_original_data self._init_states() def _init_states(self): @@ -255,6 +260,8 @@ def ready_obs(self) -> tnp.array: >>> timesteps = env_manager.step(action) """ active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN] + if self._return_original_data: + return {i: self._ready_obs[i] for i in active_env} active_env.sort() obs = [self._ready_obs.get(i) for i in active_env] if len(obs) == 0: @@ -409,16 +416,17 @@ def _recv_step_callback( remain_payloads[p.req_id] = p # make the type and content of key as similar as identifier, # in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info - info = make_key_as_identifier(info) - payload.data = tnp.array( - { - 'obs': obs, - 'reward': reward, - 'done': done, - 'info': info, - 'env_id': payload.proc_id - } - ) + if not self._return_original_data: + info = make_key_as_identifier(info) + payload.data = tnp.array( + { + 'obs': obs, + 'reward': reward, + 'done': done, + 'info': info, + 'env_id': payload.proc_id + } + ) self._ready_obs[payload.proc_id] = obs return payload diff --git a/ding/framework/__init__.py b/ding/framework/__init__.py index 4a19f56316..11b247d512 100644 --- a/ding/framework/__init__.py +++ b/ding/framework/__init__.py @@ -1,5 +1,6 @@ -from .context import Context, OnlineRLContext, OfflineRLContext +from .context import Context, OnlineRLContext, OfflineRLContext, BattleContext from .task import Task, task from .parallel import Parallel from .event_loop import EventLoop +from .event_enum import EventEnum from .supervisor import Supervisor diff --git a/ding/framework/context.py b/ding/framework/context.py index 949e0e7d01..97e178a5bd 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -75,3 +75,38 @@ def __init__(self, *args, **kwargs) -> None: self.last_eval_iter = -1 self.keep('train_iter', 'last_eval_iter') + + +class BattleContext(Context): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.__dict__ = self + # collect target paras + self.n_episode = None + + #collect process paras + self.env_episode = 0 + self.env_step = 0 + self.total_envstep_count = 0 + self.train_iter = 0 + self.collect_kwargs = {} + self.current_policies = [] + + #job paras + self.player_id_list = [] + self.job_finish = False + + #data + self.obs = None + self.actions = None + self.inference_output = {} + 
self.trajectories = None + + #Return data paras + self.episodes = [] + self.episode_info = [] + self.trajectories_list = [] + self.train_data = None + + self.keep('train_iter') diff --git a/ding/framework/event_enum.py b/ding/framework/event_enum.py new file mode 100644 index 0000000000..e7f555b548 --- /dev/null +++ b/ding/framework/event_enum.py @@ -0,0 +1,16 @@ +from enum import Enum, unique + + +@unique +class EventEnum(str, Enum): + # events emitted by coordinators + COORDINATOR_DISPATCH_ACTOR_JOB = "on_coordinator_dispatch_actor_job_{actor_id}" + + # events emitted by learners + LEARNER_SEND_MODEL = "on_learner_send_model" + LEARNER_SEND_META = "on_learner_send_meta" + + # events emitted by actors + ACTOR_GREETING = "on_actor_greeting" + ACTOR_SEND_DATA = "on_actor_send_meta_player_{player}" + ACTOR_FINISH_JOB = "on_actor_finish_job" diff --git a/ding/framework/event_loop.py b/ding/framework/event_loop.py index b5f58720a1..a4665ca01b 100644 --- a/ding/framework/event_loop.py +++ b/ding/framework/event_loop.py @@ -1,3 +1,4 @@ +import re from collections import defaultdict from typing import Callable, Optional from concurrent.futures import ThreadPoolExecutor @@ -23,6 +24,12 @@ def on(self, event: str, fn: Callable) -> None: - event (:obj:`str`): Event name. - fn (:obj:`Callable`): The function. """ + # check if the event name contains unfilled parameters. + params = re.findall(r"\{(.*?)\}", event) + if params: + raise ValueError( + "Event name missing parameters: {}. Please use String.format() to fill up".format(", ".join(params)) + ) self._listeners[event].append(fn) def off(self, event: str, fn: Optional[Callable] = None) -> None: @@ -65,6 +72,12 @@ def emit(self, event: str, *args, **kwargs) -> None: """ if self._exception: raise self._exception + # check if the event name contains unfilled parameters. + params = re.findall(r"\{(.*?)\}", event) + if params: + raise ValueError( + "Event name missing parameters: {}. 
Please use String.format() to fill up".format(", ".join(params)) + ) if self._active: self._thread_pool.submit(self._trigger, event, *args, **kwargs) diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py index a2c428932c..12484568c6 100644 --- a/ding/framework/middleware/__init__.py +++ b/ding/framework/middleware/__init__.py @@ -1,4 +1,7 @@ from .functional import * -from .collector import StepCollector, EpisodeCollector +from .collector import StepCollector, EpisodeCollector, BattleStepCollector from .learner import OffPolicyLearner, HERLearner from .ckpt_handler import CkptSaver +from .league_actor import StepLeagueActor +from .league_coordinator import LeagueCoordinator +from .league_learner_communicator import LeagueLearnerCommunicator, LearnerModel diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index fa70a00766..8d16000bef 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -1,13 +1,18 @@ -from typing import TYPE_CHECKING, Callable, List from easydict import EasyDict +from typing import Dict, TYPE_CHECKING, Callable +import time +from ditk import logging from ding.policy import get_random_policy from ding.envs import BaseEnvManager +from ding.utils import log_every_sec from ding.framework import task -from .functional import inferencer, rolloutor, TransitionList +from ding.framework.middleware.functional import PlayerModelInfo +from .functional import inferencer, rolloutor, TransitionList, BattleTransitionList, \ + battle_inferencer, battle_rolloutor if TYPE_CHECKING: - from ding.framework import OnlineRLContext + from ding.framework import OnlineRLContext, BattleContext class StepCollector: @@ -113,4 +118,126 @@ def __call__(self, ctx: "OnlineRLContext") -> None: break -# TODO battle collector +WAIT_MODEL_TIME = float('inf') + + +class BattleStepCollector: + + def __init__( + self, + cfg: EasyDict, + env: BaseEnvManager, + unroll_len: int, + model_dict: Dict, + model_info_dict: Dict, + player_policy_collect_dict: Dict, + agent_num: int, + last_step_fn: Callable = None + ): + self.cfg = cfg + self.end_flag = False + # self._reset(env) + self.env = env + self.env_num = self.env.env_num + + self.total_envstep_count = 0 + self.unroll_len = unroll_len + self.model_dict = model_dict + self.model_info_dict = model_info_dict + self.player_policy_collect_dict = player_policy_collect_dict + self.agent_num = agent_num + + self._battle_inferencer = task.wrap(battle_inferencer(self.cfg, self.env)) + self._transitions_list = [ + BattleTransitionList(self.env.env_num, self.unroll_len, last_step_fn) for _ in range(self.agent_num) + ] + self._battle_rolloutor = task.wrap( + battle_rolloutor(self.cfg, self.env, self._transitions_list, self.model_info_dict) + ) + + def __del__(self) -> None: + """ + Overview: + Execute the close command and close the collector. __del__ is automatically called to \ + destroy the collector instance when the collector finishes its work + """ + if self.end_flag: + return + self.end_flag = True + self.env.close() + + def _update_policies(self, player_id_set) -> None: + for player_id in player_id_set: + # for this player, if in the beginning of actor's lifetime, + # actor didn't recieve any new model, use initial model instead. 
+ if self.model_info_dict.get(player_id) is None: + self.model_info_dict[player_id] = PlayerModelInfo( + get_new_model_time=time.time(), update_new_model_time=None + ) + + update_player_id_set = set() + for player_id in player_id_set: + if 'historical' not in player_id: + update_player_id_set.add(player_id) + while True: + time_now = time.time() + time_list = [ + time_now - self.model_info_dict[player_id].get_new_model_time for player_id in update_player_id_set + ] + if any(x >= WAIT_MODEL_TIME for x in time_list): + for index, player_id in enumerate(update_player_id_set): + if time_list[index] >= WAIT_MODEL_TIME: + log_every_sec( + logging.WARNING, 5, + 'In actor {}, model for {} is not updated for {} senconds, and need new model'.format( + task.router.node_id, player_id, time_list[index] + ) + ) + time.sleep(1) + else: + break + + for player_id in update_player_id_set: + if self.model_dict.get(player_id) is None: + continue + else: + learner_model = self.model_dict.get(player_id) + policy = self.player_policy_collect_dict.get(player_id) + assert policy, "for player{}, policy should have been initialized already" + # update policy model + policy.load_state_dict(learner_model.state_dict) + self.model_info_dict[player_id].update_new_model_time = time.time() + self.model_info_dict[player_id].update_train_iter = learner_model.train_iter + self.model_dict[player_id] = None + + def __call__(self, ctx: "BattleContext") -> None: + + ctx.total_envstep_count = self.total_envstep_count + old = ctx.env_step + + while True: + if self.env.closed: + self.env.launch() + for env_id in range(self.env_num): + for policy in ctx.current_policies: + policy.reset([env_id]) + self._update_policies(set(ctx.player_id_list)) + self._battle_inferencer(ctx) + self._battle_rolloutor(ctx) + + self.total_envstep_count = ctx.total_envstep_count + + only_finished = True if ctx.env_episode >= ctx.n_episode else False + if (self.unroll_len > 0 and ctx.env_step - old >= self.unroll_len) or ctx.env_episode >= ctx.n_episode: + for transitions in self._transitions_list: + trajectories = transitions.to_trajectories(only_finished=only_finished) + ctx.trajectories_list.append(trajectories) + if ctx.env_episode >= ctx.n_episode: + self.env.close() + ctx.job_finish = True + for transitions in self._transitions_list: + transitions.clear() + break + + +# TODO BattleEpisodeCollector diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 5f8ecdfb7d..d3e124d2c3 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -1,7 +1,8 @@ from .trainer import trainer, multistep_trainer from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, \ sqil_data_pusher -from .collector import inferencer, rolloutor, TransitionList +from .collector import inferencer, rolloutor, TransitionList, BattleTransitionList, \ + battle_inferencer, battle_rolloutor from .evaluator import interaction_evaluator from .termination_checker import termination_checker from .ctx_helper import final_ctx_saver @@ -10,3 +11,4 @@ from .explorer import eps_greedy_handler, eps_greedy_masker from .advantage_estimator import gae_estimator from .enhancer import reward_estimator, her_data_enhancer, nstep_reward_enhancer +from .actor_data import ActorData, ActorDataMeta, ActorEnvTrajectories, PlayerModelInfo diff --git a/ding/framework/middleware/functional/actor_data.py 
b/ding/framework/middleware/functional/actor_data.py new file mode 100644 index 0000000000..f86cd12325 --- /dev/null +++ b/ding/framework/middleware/functional/actor_data.py @@ -0,0 +1,30 @@ +from typing import Any, List +from dataclasses import dataclass, field + +#TODO(zms): simplify fields + + +@dataclass +class ActorDataMeta: + player_total_env_step: int = 0 + actor_id: int = 0 + send_wall_time: float = 0.0 + + +@dataclass +class ActorEnvTrajectories: + env_id: int = 0 + trajectories: List = field(default_factory=list) + + +@dataclass +class ActorData: + meta: ActorDataMeta + train_data: List[ActorEnvTrajectories] = field(default_factory=list) + + +@dataclass +class PlayerModelInfo: + get_new_model_time: float + update_new_model_time: float + update_train_iter: int = 0 diff --git a/ding/framework/middleware/functional/collector.py b/ding/framework/middleware/functional/collector.py index fa47e7cf83..ab59b3c02b 100644 --- a/ding/framework/middleware/functional/collector.py +++ b/ding/framework/middleware/functional/collector.py @@ -1,13 +1,22 @@ -from typing import TYPE_CHECKING, Callable, List, Tuple, Any +from typing import TYPE_CHECKING, Optional, Callable, List, Tuple, Any, Dict from easydict import EasyDict from functools import reduce import treetensor.torch as ttorch -from ding.envs import BaseEnvManager +from ding.envs.env_manager.base_env_manager import BaseEnvManager +from ding.envs.env.base_env import BaseEnvTimestep from ding.policy import Policy -from ding.torch_utils import to_ndarray +import torch +from ding.utils import dicts_to_lists +from ding.torch_utils import to_tensor, to_ndarray +from ding.framework import task -if TYPE_CHECKING: - from ding.framework import OnlineRLContext +# if TYPE_CHECKING: +from ding.framework import OnlineRLContext, BattleContext +from collections import deque +from ding.framework.middleware.functional.actor_data import ActorEnvTrajectories +from copy import deepcopy + +from ditk import logging class TransitionList: @@ -45,6 +54,125 @@ def clear(self): item.clear() + +class BattleTransitionList: + + def __init__(self, env_num: int, unroll_len: int, last_step_fn: Callable = None) -> None: + # for each env, we have a deque to buffer episodes, + # and a deque to tell each episode is finished or not + self.env_num = env_num + self._transitions = [deque() for _ in range(env_num)] + self._done_episode = [deque() for _ in range(env_num)] + self._unroll_len = unroll_len + self._last_step_fn = last_step_fn + + def get_env_trajectories(self, env_id: int, only_finished: bool = False) -> List[List]: + trajectories = [] + if len(self._transitions[env_id]) == 0: + # if we have no episode for this env, we return an empty list + return trajectories + while len(self._transitions[env_id]) > 0: + # Every time we check whether the oldest episode is done; + # if it is done, we cut the episode into trajectories + # and finally drop this episode + if self._done_episode[env_id][0] is False: + break + oldest_episode = self._transitions[env_id].popleft() + self._done_episode[env_id].popleft() + trajectories += self._cut_trajectory_from_episode(oldest_episode) + oldest_episode.clear() + + if not only_finished and len(self._transitions[env_id]) == 1 and self._done_episode[env_id][0] is False: + # If last episode is not done, we only cut the trajectories till the Trajectory(t-1) (not including) + # This is because we need Trajectory(t-1) to fill up Trajectory(t) if in Trajectory(t) this episode is done + tail_idx = max( + 0, ((len(self._transitions[env_id][0]) - self._unroll_len) // 
self._unroll_len) * self._unroll_len + ) + trajectories += self._cut_trajectory_from_episode(self._transitions[env_id][0][:tail_idx]) + self._transitions[env_id][0] = self._transitions[env_id][0][tail_idx:] + + return trajectories + + def to_trajectories(self, only_finished: bool = False) -> List[ActorEnvTrajectories]: + all_env_data = [] + for env_id in range(self.env_num): + trajectories = self.get_env_trajectories(env_id, only_finished=only_finished) + if len(trajectories) > 0: + all_env_data.append(ActorEnvTrajectories(env_id=env_id, trajectories=trajectories)) + return all_env_data + + def _cut_trajectory_from_episode(self, episode: list) -> List[List]: + # first we cut complete trajectories (lists of transitions whose length equals unroll_len) + # then we gather the transitions in the tail of episode, + # and fill up the trajectory with the tail transitions in Trajectory(t-1) + # If we don't have Trajectory(t-1), i.e. the length of the whole episode is smaller than unroll_len, + # we fill up the trajectory with the first element of episode. + return_episode = [] + i = 0 + num_complete_trajectory, num_tail_transitions = divmod(len(episode), self._unroll_len) + for i in range(num_complete_trajectory): + trajectory = episode[i * self._unroll_len:(i + 1) * self._unroll_len] + if self._last_step_fn: + last_step = deepcopy(trajectory[-1]) + last_step = self._last_step_fn(last_step) + trajectory.append(last_step) + return_episode.append(trajectory) + + if num_tail_transitions > 0: + trajectory = episode[-self._unroll_len:] + if len(trajectory) < self._unroll_len: + initial_elements = [] + for _ in range(self._unroll_len - len(trajectory)): + initial_elements.append(trajectory[0]) + trajectory = initial_elements + trajectory + if self._last_step_fn: + last_step = deepcopy(trajectory[-1]) + last_step = self._last_step_fn(last_step) + trajectory.append(last_step) + return_episode.append(trajectory) + + return return_episode # list of trajectories + + def clear_newest_episode(self, env_id: int, before_append=False) -> int: + # Call this method when env.step raises some error + + # If this method is called before append, and the last episode of this env is done, + # it means that the env had some error at the first step of the newest episode, + # and we should not delete the last episode because it is a normal episode. 
+ if before_append is True and len(self._done_episode[env_id]) > 0 and self._done_episode[env_id][-1] == True: + return 0 + if len(self._transitions[env_id]) > 0: + newest_episode = self._transitions[env_id].pop() + len_newest_episode = len(newest_episode) + newest_episode.clear() + self._done_episode[env_id].pop() + return len_newest_episode + else: + return 0 + + def append(self, env_id: int, transition: Any) -> bool: + # If previous episode is done, we create a new episode + if len(self._done_episode[env_id]) == 0 or self._done_episode[env_id][-1] is True: + self._transitions[env_id].append([]) + self._done_episode[env_id].append(False) + self._transitions[env_id][-1].append(transition) + if transition.done: + self._done_episode[env_id][-1] = True + if len(self._transitions[env_id][-1]) < self._unroll_len: + logging.warning( + 'The length of the newest finished episode in node {}, env {}, is {}, '\ + 'which is shorter than unroll_len: {}, and need to be dropped' + .format(task.router.node_id, env_id, len(self._transitions[env_id][-1]), self._unroll_len) + ) + return False + return True + + def clear(self) -> None: + for item in self._transitions: + item.clear() + for item in self._done_episode: + item.clear() + + def inferencer(cfg: EasyDict, policy: Policy, env: BaseEnvManager) -> Callable: """ Overview: @@ -132,3 +260,76 @@ def _rollout(ctx: "OnlineRLContext"): # TODO log return _rollout + + +def battle_inferencer(cfg: EasyDict, env: BaseEnvManager): + + def _battle_inferencer(ctx: "BattleContext"): + # Get current env obs. + obs = env.ready_obs + + # Policy forward. + if cfg.transform_obs: + obs = to_tensor(obs, dtype=torch.float32) + obs = dicts_to_lists(obs) + inference_output = [p.forward(obs[i], **ctx.collect_kwargs) for i, p in enumerate(ctx.current_policies)] + ctx.obs = obs + ctx.inference_output = inference_output + # Interact with env. 
+ actions = {} + for env_id in range(env.env_num): + actions[env_id] = [] + for output in inference_output: + actions[env_id].append(output[env_id]['action']) + ctx.actions = to_ndarray(actions) + + return _battle_inferencer + + +def battle_rolloutor(cfg: EasyDict, env: BaseEnvManager, transitions_list: List, model_info_dict: Dict): + + def _battle_rolloutor(ctx: "BattleContext"): + timesteps = env.step(ctx.actions) + ctx.total_envstep_count += len(timesteps) + ctx.env_step += len(timesteps) + + if isinstance(timesteps, list): + new_time_steps = {} + for env_id, timestep in enumerate(timesteps): + new_time_steps[env_id] = timestep + timesteps = new_time_steps + + for env_id, timestep in timesteps.items(): + if isinstance(timestep.info, dict) and timestep.info.get('abnormal'): + for policy_id, policy in enumerate(ctx.current_policies): + transitions_list[policy_id].clear_newest_episode(env_id, before_append=True) + policy.reset([env_id]) + continue + + episode_long_enough = True + for policy_id, policy in enumerate(ctx.current_policies): + policy_timestep_data = [d[policy_id] if not isinstance(d, bool) else d for d in timestep] + policy_timestep = type(timestep)(*policy_timestep_data) + transition = policy.process_transition( + ctx.obs[policy_id][env_id], ctx.inference_output[policy_id][env_id], policy_timestep + ) + transition = ttorch.as_tensor(transition) + transition.collect_train_iter = ttorch.as_tensor( + [model_info_dict[ctx.player_id_list[policy_id]].update_train_iter] + ) + + episode_long_enough = episode_long_enough and transitions_list[policy_id].append(env_id, transition) + + if timestep.done: + for policy_id, policy in enumerate(ctx.current_policies): + policy.reset([env_id]) + ctx.episode_info[policy_id].append(timestep.info[policy_id]) + + if not episode_long_enough: + for policy_id, _ in enumerate(ctx.current_policies): + transitions_list[policy_id].clear_newest_episode(env_id) + ctx.episode_info[policy_id].pop() + elif timestep.done: + ctx.env_episode += 1 + + return _battle_rolloutor diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 38b38bb922..c4bc649386 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -4,6 +4,7 @@ import torch from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type from ding.data.buffer.middleware import PriorityExperienceReplay +from ding.utils.sparse_logging import log_every_sec if TYPE_CHECKING: from ding.framework import OnlineRLContext, OfflineRLContext @@ -111,7 +112,8 @@ def _fetch(ctx: "OnlineRLContext"): assert buffered_data is not None except (ValueError, AssertionError): # You can modify data collect config to avoid this warning, e.g. increasing n_sample, n_episode. - logging.warning( + log_every_sec( + logging.WARNING, 10, "Replay buffer's data is not enough to support training, so skip this training for waiting more data." 
) ctx.train_data = None diff --git a/ding/framework/middleware/functional/trainer.py b/ding/framework/middleware/functional/trainer.py index 5c7f99467b..2965a6dcaf 100644 --- a/ding/framework/middleware/functional/trainer.py +++ b/ding/framework/middleware/functional/trainer.py @@ -6,7 +6,7 @@ from ding.framework import task if TYPE_CHECKING: - from ding.framework import OnlineRLContext, OfflineRLContext + from ding.framework import OnlineRLContext, OfflineRLContext, BattleContext def trainer(cfg: EasyDict, policy: Policy) -> Callable: @@ -18,7 +18,7 @@ def trainer(cfg: EasyDict, policy: Policy) -> Callable: - policy (:obj:`Policy`): The policy to be trained in step-by-step mode. """ - def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]): + def _train(ctx: Union["OnlineRLContext", "OfflineRLContext", "BattleContext"]): """ Input of ctx: - train_data (:obj:`Dict`): The data used to update the network. It will train only if \ diff --git a/ding/framework/middleware/league_actor.py b/ding/framework/middleware/league_actor.py new file mode 100644 index 0000000000..84b0832157 --- /dev/null +++ b/ding/framework/middleware/league_actor.py @@ -0,0 +1,228 @@ +from typing import TYPE_CHECKING, Dict, Callable +from threading import Lock +import queue +from easydict import EasyDict +import time +from ditk import logging +import torch +import gc + +from ding.policy import Policy +from ding.framework import task, EventEnum +from ding.framework.middleware import BattleStepCollector +from ding.framework.middleware.functional import ActorData, ActorDataMeta, PlayerModelInfo +from ding.league.player import PlayerMeta +from ding.utils.sparse_logging import log_every_sec + +if TYPE_CHECKING: + from ding.league.v2.base_league import Job + from ding.framework import BattleContext + from ding.framework.middleware.league_learner_communicator import LearnerModel + + +class StepLeagueActor: + + def __init__(self, cfg: EasyDict, env_fn: Callable, policy_fn: Callable, last_step_fn: Callable = None): + self.cfg = cfg + self.env_fn = env_fn + self.env_num = env_fn().env_num + self.policy_fn = policy_fn + self.unroll_len = self.cfg.policy.collect.unroll_len + self._collectors: Dict[str, BattleStepCollector] = {} + self.player_policy_dict: Dict[str, "Policy"] = {} + self.player_policy_collect_dict: Dict[str, "Policy.collect_function"] = {} + + task.on(EventEnum.COORDINATOR_DISPATCH_ACTOR_JOB.format(actor_id=task.router.node_id), self._on_league_job) + task.on(EventEnum.LEARNER_SEND_MODEL, self._on_learner_model) + + self.job_queue = queue.Queue() + self.model_dict = {} + self.model_dict_lock = Lock() + self.model_info_dict = {} + self.agent_num = 2 + self.last_step_fn = last_step_fn + + self.traj_num = 0 + self.total_time = 0 + self.total_episode_num = 0 + + def _on_learner_model(self, learner_model: "LearnerModel"): + """ + If get newest learner model, put it inside model_queue. 
+ """ + log_every_sec( + logging.INFO, 5, '[Actor {}] recieved model {}'.format(task.router.node_id, learner_model.player_id) + ) + with self.model_dict_lock: + self.model_dict[learner_model.player_id] = learner_model + if self.model_info_dict.get(learner_model.player_id): + self.model_info_dict[learner_model.player_id].get_new_model_time = time.time() + self.model_info_dict[learner_model.player_id].update_new_model_time = None + else: + self.model_info_dict[learner_model.player_id] = PlayerModelInfo( + get_new_model_time=time.time(), update_new_model_time=None + ) + + def _on_league_job(self, job: "Job"): + """ + Deal with job distributed by coordinator, put it inside job_queue. + """ + self.job_queue.put(job) + + def _get_collector(self, player_id: str): + if self._collectors.get(player_id): + return self._collectors.get(player_id) + cfg = self.cfg + env = self.env_fn() + collector = task.wrap( + BattleStepCollector( + cfg.policy.collect.collector, env, self.unroll_len, self.model_dict, self.model_info_dict, + self.player_policy_collect_dict, self.agent_num, self.last_step_fn + ) + ) + self._collectors[player_id] = collector + return collector + + def _get_policy(self, player: "PlayerMeta", duplicate: bool = False) -> "Policy.collect_function": + player_id = player.player_id + + if self.player_policy_collect_dict.get(player_id): + player_policy_collect_mode = self.player_policy_collect_dict.get(player_id) + if duplicate is False: + return player_policy_collect_mode + else: + player_policy = self.player_policy_dict.get(player_id) + duplicate_policy: "Policy.collect_function" = self.policy_fn() + del duplicate_policy._collect_model + duplicate_policy._collect_model = player_policy._collect_model + if getattr(player_policy, 'teacher_model') and getattr(duplicate_policy, 'teacher_model'): + del duplicate_policy.teacher_model + duplicate_policy.teacher_model = player_policy.teacher_model + return duplicate_policy.collect_mode + else: + policy: "Policy.collect_function" = self.policy_fn() + self.player_policy_dict[player_id] = policy + + policy_collect_mode = policy.collect_mode + self.player_policy_collect_dict[player_id] = policy_collect_mode + # TODO(zms): not only historical players, but also other players should + # update the policies to the checkpoint in job + if "historical" in player.player_id: + policy_collect_mode.load_state_dict(player.checkpoint.load()) + + return policy_collect_mode + + def _get_job(self): + if self.job_queue.empty(): + task.emit(EventEnum.ACTOR_GREETING, task.router.node_id) + job = None + + try: + job = self.job_queue.get(timeout=10) + except queue.Empty: + logging.warning("[Actor {}] no Job got from coordinator.".format(task.router.node_id)) + + return job + + def _get_current_policies(self, job): + current_policies = [] + main_player: "PlayerMeta" = None + player_set = set() + for player in job.players: + if player.player_id not in player_set: + current_policies.append(self._get_policy(player, duplicate=False)) + player_set.add(player.player_id) + else: + current_policies.append(self._get_policy(player, duplicate=True)) + if player.player_id == job.launch_player: + main_player = player + assert main_player, "[Actor {}] cannot find active player.".format(task.router.node_id) + assert current_policies, "[Actor {}] current_policies should not be None".format(task.router.node_id) + + return main_player, current_policies + + def __call__(self, ctx: "BattleContext"): + + job = self._get_job() + if job is None: + return + print('[Actor {}] recieve job 
{}'.format(task.router.node_id, job)) + log_every_sec( + logging.INFO, 5, '[Actor {}] job of player {} begins.'.format(task.router.node_id, job.launch_player) + ) + + ctx.player_id_list = [player.player_id for player in job.players] + main_player_idx = [idx for idx, player in enumerate(job.players) if player.player_id == job.launch_player] + self.agent_num = len(job.players) + collector = self._get_collector(job.launch_player) + + _, ctx.current_policies = self._get_current_policies(job) + + ctx.n_episode = self.cfg.policy.collect.n_episode + assert ctx.n_episode >= self.env_num, "[Actor {}] Please make sure n_episode >= env_num".format( + task.router.node_id + ) + + ctx.n_episode = self.cfg.policy.collect.n_episode + assert ctx.n_episode >= self.env_num, "Please make sure n_episode >= env_num" + + ctx.episode_info = [[] for _ in range(self.agent_num)] + + while True: + time_begin = time.time() + collector(ctx) + + if ctx.job_finish is True: + logging.info('[Actor {}] finish current job !'.format(task.router.node_id)) + + for idx in main_player_idx: + if not job.is_eval and len(ctx.trajectories_list[idx]) > 0: + trajectories = ctx.trajectories_list[idx] + self.traj_num += len(trajectories) + meta_data = ActorDataMeta( + player_total_env_step=ctx.total_envstep_count, + actor_id=task.router.node_id, + send_wall_time=time.time() + ) + actor_data = ActorData(meta=meta_data, train_data=trajectories) + task.emit(EventEnum.ACTOR_SEND_DATA.format(player=job.launch_player), actor_data) + + ctx.trajectories_list = [] + + time_end = time.time() + self.total_time += time_end - time_begin + log_every_sec( + logging.INFO, 5, + '[Actor {}] sent {} trajectories till now, total trajectory send speed is {} traj/s'.format( + task.router.node_id, + self.traj_num, + self.traj_num / self.total_time, + ) + ) + + gc.collect() + + if ctx.job_finish is True: + job.result = [] + for idx in main_player_idx: + for e in ctx.episode_info[idx]: + job.result.append(e['result']) + task.emit(EventEnum.ACTOR_FINISH_JOB, job) + ctx.episode_info = [[] for _ in range(self.agent_num)] + logging.info('[Actor {}] job finish, send job\n'.format(task.router.node_id)) + break + + self.total_episode_num += ctx.env_episode + logging.info( + '[Actor {}] finish {} episodes till now, speed is {} episode/s'.format( + task.router.node_id, self.total_episode_num, self.total_episode_num / self.total_time + ) + ) + logging.info( + '[Actor {}] sent {} trajectories till now, the episode trajectory speed is {} traj/episode'.format( + task.router.node_id, self.traj_num, self.traj_num / self.total_episode_num + ) + ) + + +#TODO: EpisodeLeagueActor diff --git a/ding/framework/middleware/league_coordinator.py b/ding/framework/middleware/league_coordinator.py new file mode 100644 index 0000000000..209de2f267 --- /dev/null +++ b/ding/framework/middleware/league_coordinator.py @@ -0,0 +1,88 @@ +from collections import defaultdict +from time import sleep, time +from threading import Lock +from dataclasses import dataclass +from typing import TYPE_CHECKING +from ding.framework import task, EventEnum +from ditk import logging + +from ding.utils.sparse_logging import log_every_sec + +if TYPE_CHECKING: + from easydict import EasyDict + from ding.framework import Task, Context + from ding.league.v2 import BaseLeague + from ding.league.player import PlayerMeta + from ding.league.v2.base_league import Job + + +class LeagueCoordinator: + + def __init__(self, cfg: "EasyDict", league: "BaseLeague") -> None: + self.league = league + self._lock = Lock() + 
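A note on the event wiring used by the actor and coordinator above: EventEnum members that contain placeholders (COORDINATOR_DISPATCH_ACTOR_JOB, ACTOR_SEND_DATA) must be filled in with str.format() before they are registered or emitted, and the new check in EventLoop.on()/emit() raises a ValueError listing the missing placeholders otherwise. A minimal sketch of that convention, illustrative only (the player id is made up):

```python
from ding.framework import task, EventEnum
from ding.framework.context import BattleContext

# Sketch, not part of this diff: fill the "{player}" placeholder before subscribing/emitting.
with task.start(async_mode=True, ctx=BattleContext()):
    player_id = "main_player_default_0"  # hypothetical id, for illustration only
    task.on(EventEnum.ACTOR_SEND_DATA.format(player=player_id), lambda data: print(data))
    task.emit(EventEnum.ACTOR_SEND_DATA.format(player=player_id), {"example": 1})
    # task.emit(EventEnum.ACTOR_SEND_DATA, {"example": 1})  # would raise ValueError: missing "player"
```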
self._total_send_jobs = 0 + self._total_recv_jobs = 0 + self._eval_frequency = 10 + self._running_jobs = dict() + self._last_collect_time = None + self._total_collect_time = None + + task.on(EventEnum.ACTOR_GREETING, self._on_actor_greeting) + task.on(EventEnum.LEARNER_SEND_META, self._on_learner_meta) + task.on(EventEnum.ACTOR_FINISH_JOB, self._on_actor_job) + + def _on_actor_greeting(self, actor_id): + logging.info("[Coordinator {}] received greeting from actor {}".format(task.router.node_id, actor_id)) + if self._last_collect_time is None: + self._last_collect_time = time() + if self._total_collect_time is None: + self._total_collect_time = 0 + with self._lock: + player_num = len(self.league.active_players_ids) + player_id = self.league.active_players_ids[self._total_send_jobs % player_num] + job = self.league.get_job_info(player_id) + job.job_no = self._total_send_jobs + self._total_send_jobs += 1 + if job.job_no > 0 and job.job_no % self._eval_frequency == 0: + job.is_eval = True + job.actor_id = actor_id + self._running_jobs["actor_{}".format(actor_id)] = job + task.emit(EventEnum.COORDINATOR_DISPATCH_ACTOR_JOB.format(actor_id=actor_id), job) + + def _on_learner_meta(self, player_meta: "PlayerMeta"): + log_every_sec( + logging.INFO, 5, + '[Coordinator {}] received learner meta from player {}'.format(task.router.node_id, player_meta.player_id) + ) + self.league.update_active_player(player_meta) + self.league.create_historical_player(player_meta) + + def _on_actor_job(self, job: "Job"): + if self._last_collect_time is None: + self._last_collect_time = time() + if self._total_collect_time is None: + self._total_collect_time = 0 + + self._total_recv_jobs += 1 + old_time = self._last_collect_time + self._last_collect_time = time() + self._total_collect_time += self._last_collect_time - old_time + logging.info( + "[Coordinator {}] received finished job of player {}, "\ + "received {} jobs in total, collect job speed is {} s/job" + .format( + task.router.node_id, job.launch_player, self._total_recv_jobs, + self._total_collect_time / self._total_recv_jobs + ) + ) + self.league.update_payoff(job) + + def __del__(self): + logging.info("[Coordinator {}] all tasks finished, coordinator closed".format(task.router.node_id)) + + def __call__(self, ctx: "Context") -> None: + sleep(1) + log_every_sec( + logging.INFO, 600, "[Coordinator {}] running jobs {}".format(task.router.node_id, self._running_jobs) + ) diff --git a/ding/framework/middleware/league_learner_communicator.py b/ding/framework/middleware/league_learner_communicator.py new file mode 100644 index 0000000000..1e437bfd74 --- /dev/null +++ b/ding/framework/middleware/league_learner_communicator.py @@ -0,0 +1,70 @@ +from ditk import logging +import os +from dataclasses import dataclass +from collections import deque +from time import sleep +from typing import TYPE_CHECKING + +from ding.framework import task, EventEnum +from ding.framework.storage import FileStorage +from ding.league.player import PlayerMeta +from ding.utils.sparse_logging import log_every_sec + +if TYPE_CHECKING: + from ding.policy import Policy + from ding.framework import BattleContext + from ding.framework.middleware.league_actor import ActorData + from ding.league import ActivePlayer + + +@dataclass +class LearnerModel: + player_id: str + state_dict: dict + train_iter: int = 0 + + +class LeagueLearnerCommunicator: + + def __init__(self, cfg: dict, policy: "Policy", player: "ActivePlayer") -> None: + self.cfg = cfg + self._cache = deque(maxlen=20) + self.player = player + 
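To make the dispatch rule in _on_actor_greeting above concrete, here is a small worked sketch (plain Python, not part of the diff): jobs are handed out round-robin over the active players, and after the first job every _eval_frequency-th job is flagged as an evaluation job.

```python
# Worked example of the coordinator's dispatch rule, assuming 3 active players
# and the default _eval_frequency of 10 (illustration only).
active_players_ids = ["player_0", "player_1", "player_2"]
eval_frequency = 10
for job_no in range(12):
    player_id = active_players_ids[job_no % len(active_players_ids)]
    is_eval = job_no > 0 and job_no % eval_frequency == 0
    print(job_no, player_id, is_eval)
# job 0 -> player_0 (collect), job 1 -> player_1, ..., job 10 -> player_1 (eval), job 11 -> player_2
```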
self.player_id = player.player_id + self.policy = policy + self.prefix = '{}/ckpt'.format(cfg.exp_name) + if not os.path.exists(self.prefix): + os.makedirs(self.prefix) + task.on(EventEnum.ACTOR_SEND_DATA.format(player=self.player_id), self._push_data) + + def _push_data(self, data: "ActorData"): + log_every_sec( + logging.INFO, 5, + "[Learner {}] receive data of player {} from actor! \n".format(task.router.node_id, self.player_id) + ) + for env_trajectories in data.train_data: + for traj in env_trajectories.trajectories: + self._cache.append(traj) + + def __call__(self, ctx: "BattleContext"): + ctx.trajectories = list(self._cache) + self._cache.clear() + sleep(0.0001) + yield + log_every_sec(logging.INFO, 20, "[Learner {}] ctx.train_iter {}".format(task.router.node_id, ctx.train_iter)) + self.player.total_agent_step = ctx.train_iter + if self.player.is_trained_enough(): + logging.info('{1} [Learner {0}] trained enough! {1} \n\n'.format(task.router.node_id, "-" * 40)) + storage = FileStorage( + path=os.path.join(self.prefix, "{}_{}_ckpt.pth".format(self.player_id, ctx.train_iter)) + ) + storage.save(self.policy.state_dict()) + task.emit( + EventEnum.LEARNER_SEND_META, + PlayerMeta(player_id=self.player_id, checkpoint=storage, total_agent_step=ctx.train_iter) + ) + + learner_model = LearnerModel( + player_id=self.player_id, state_dict=self.policy.state_dict(), train_iter=ctx.train_iter + ) + task.emit(EventEnum.LEARNER_SEND_MODEL, learner_model) diff --git a/ding/framework/middleware/tests/__init__.py b/ding/framework/middleware/tests/__init__.py index 5bb84e7fe2..8e6f536af1 100644 --- a/ding/framework/middleware/tests/__init__.py +++ b/ding/framework/middleware/tests/__init__.py @@ -1 +1 @@ -from .mock_for_test import MockEnv, MockPolicy, MockHerRewardModel, CONFIG +from .mock_for_test import * diff --git a/ding/framework/middleware/tests/mock_for_test.py b/ding/framework/middleware/tests/mock_for_test.py index 7fa6cde7aa..cfb937bf09 100644 --- a/ding/framework/middleware/tests/mock_for_test.py +++ b/ding/framework/middleware/tests/mock_for_test.py @@ -1,10 +1,16 @@ -from typing import Union, Any, List, Callable, Dict, Optional +from typing import TYPE_CHECKING, Union, Any, List, Callable, Dict, Optional from collections import namedtuple +import random import torch import treetensor.numpy as tnp from easydict import EasyDict from unittest.mock import Mock +from ding.torch_utils import to_device +from ding.league.player import PlayerMeta +from ding.league.v2 import BaseLeague, Job +from ding.framework.storage import FileStorage + obs_dim = [2, 2] action_space = 1 env_num = 2 @@ -116,3 +122,211 @@ def __init__(self) -> None: def estimate(self, episode: List[Dict[str, Any]]) -> List[Dict[str, Any]]: return [[episode[0] for _ in range(self.episode_element_size)]] + + +class MockLeague(BaseLeague): + + def __init__(self, cfg) -> None: + super().__init__(cfg) + self.update_payoff_cnt = 0 + self.update_active_player_cnt = 0 + self.create_historical_player_cnt = 0 + self.get_job_info_cnt = 0 + + def update_payoff(self, job): + self.update_payoff_cnt += 1 + + def update_active_player(self, meta): + self.update_active_player_cnt += 1 + + def create_historical_player(self, meta): + self.create_historical_player_cnt += 1 + + def get_job_info(self, player_id): + self.get_job_info_cnt += 1 + other_players = [i for i in self.active_players_ids if i != player_id] + another_palyer = random.choice(other_players) + return Job( + launch_player=player_id, + players=[ + PlayerMeta(player_id=player_id, 
checkpoint=FileStorage(path=None), total_agent_step=0), + PlayerMeta(player_id=another_palyer, checkpoint=FileStorage(path=None), total_agent_step=0) + ] + ) + + +class MockLogger(): + + def add_scalar(*args): + pass + + def close(*args): + pass + + def flush(*args): + pass + + +league_cfg = EasyDict( + { + 'env': { + 'manager': { + 'episode_num': 100000, + 'max_retry': 1000, + 'retry_type': 'renew', + 'auto_reset': True, + 'step_timeout': None, + 'reset_timeout': None, + 'retry_waiting_time': 0.1, + 'cfg_type': 'BaseEnvManagerDict', + 'shared_memory': False, + 'return_original_data': True + }, + 'collector_env_num': 1, + 'evaluator_env_num': 1, + 'n_evaluator_episode': 100, + 'env_type': 'prisoner_dilemma', + 'stop_value': [-10.1, -5.05] + }, + 'policy': { + 'model': { + 'obs_shape': 2, + 'action_shape': 2, + 'action_space': 'discrete', + 'encoder_hidden_size_list': [32, 32], + 'critic_head_hidden_size': 32, + 'actor_head_hidden_size': 32, + 'share_encoder': False + }, + 'learn': { + 'learner': { + 'train_iterations': 1000000000, + 'dataloader': { + 'num_workers': 0 + }, + 'log_policy': False, + 'hook': { + 'load_ckpt_before_run': '', + 'log_show_after_iter': 100, + 'save_ckpt_after_iter': 10000, + 'save_ckpt_after_run': True + }, + 'cfg_type': 'BaseLearnerDict' + }, + 'multi_gpu': False, + 'epoch_per_collect': 10, + 'batch_size': 16, + 'learning_rate': 1e-05, + 'value_weight': 0.5, + 'entropy_weight': 0.0, + 'clip_ratio': 0.2, + 'adv_norm': True, + 'value_norm': True, + 'ppo_param_init': True, + 'grad_clip_type': 'clip_norm', + 'grad_clip_value': 0.5, + 'ignore_done': False, + 'update_per_collect': 3, + 'scheduler': { + 'schedule_flag': False, + 'schedule_mode': 'reduce', + 'factor': 0.005, + 'change_range': [0, 1], + 'threshold': 0.5, + 'patience': 50 + } + }, + 'collect': { + 'collector': { + 'deepcopy_obs': False, + 'transform_obs': False, + 'collect_print_freq': 100, + 'get_train_sample': True, + 'cfg_type': 'BattleEpisodeSerialCollectorDict' + }, + 'discount_factor': 1.0, + 'gae_lambda': 1.0, + 'n_episode': 1, + 'n_rollout_samples': 64, + 'n_sample': 64, + 'unroll_len': 1 + }, + 'eval': { + 'evaluator': { + 'eval_freq': 50, + 'cfg_type': 'BattleInteractionSerialEvaluatorDict', + 'stop_value': [-10.1, -5.05], + 'n_episode': 100 + } + }, + 'other': { + 'replay_buffer': { + 'type': 'naive', + 'replay_buffer_size': 10000, + 'deepcopy': False, + 'enable_track_used_data': False, + 'periodic_thruput_seconds': 60, + 'cfg_type': 'NaiveReplayBufferDict' + }, + 'league': { + 'player_category': ['default'], + 'path_policy': 'league_demo/ckpt', + 'active_players': { + 'main_player': 2 + }, + 'main_player': { + 'one_phase_step': 10, # 20 + 'branch_probs': { + 'pfsp': 0.0, + 'sp': 1.0 + }, + 'strong_win_rate': 0.7 + }, + 'main_exploiter': { + 'one_phase_step': 200, + 'branch_probs': { + 'main_players': 1.0 + }, + 'strong_win_rate': 0.7, + 'min_valid_win_rate': 0.3 + }, + 'league_exploiter': { + 'one_phase_step': 200, + 'branch_probs': { + 'pfsp': 1.0 + }, + 'strong_win_rate': 0.7, + 'mutate_prob': 0.5 + }, + 'use_pretrain': False, + 'use_pretrain_init_historical': False, + 'payoff': { + 'type': 'battle', + 'decay': 0.99, + 'min_win_rate_games': 8 + }, + 'metric': { + 'mu': 0, + 'sigma': 8.333333333333334, + 'beta': 4.166666666666667, + 'tau': 0.0, + 'draw_probability': 0.02 + } + } + }, + 'type': 'ppo', + 'cuda': False, + 'on_policy': True, + 'priority': False, + 'priority_IS_weight': False, + 'recompute_adv': True, + 'action_space': 'discrete', + 'nstep_return': False, + 'multi_agent': False, + 
'transition_with_policy_data': True, + 'cfg_type': 'PPOPolicyDict' + }, + 'exp_name': 'league_demo', + 'seed': 0 + } +) diff --git a/ding/framework/middleware/tests/test_collector.py b/ding/framework/middleware/tests/test_collector.py index 16c577ae2b..bfb1eaf9e0 100644 --- a/ding/framework/middleware/tests/test_collector.py +++ b/ding/framework/middleware/tests/test_collector.py @@ -6,6 +6,8 @@ from ding.framework.middleware import TransitionList, inferencer, rolloutor from ding.framework.middleware import StepCollector, EpisodeCollector from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG +from ding.framework.middleware import BattleTransitionList +from easydict import EasyDict @pytest.mark.unittest @@ -84,3 +86,107 @@ def test_episode_collector(): collector = EpisodeCollector(cfg, policy, env, random_collect_size=8) collector(ctx) assert len(ctx.episodes) == 16 + + +def test_battle_transition(): + env_num = 2 + unroll_len = 32 + transition_list = BattleTransitionList(env_num, unroll_len) + len_env_0 = 48 + len_env_1 = 72 + + for i in range(len_env_0): + timestep = EasyDict({'obs': i, 'done': False}) + transition_list.append(env_id=0, transition=timestep) + + transition_list.append(env_id=0, transition=EasyDict({'obs': len_env_0, 'done': True})) + + for i in range(len_env_1): + timestep = EasyDict({'obs': i, 'done': False}) + transition_list.append(env_id=1, transition=timestep) + + transition_list.append(env_id=1, transition=EasyDict({'obs': len_env_1, 'done': True})) + + len_env_0_2 = 12 + len_env_1_2 = 72 + + for i in range(len_env_0_2): + timestep = EasyDict({'obs': i, 'done': False}) + transition_list.append(env_id=0, transition=timestep) + + transition_list.append(env_id=0, transition=EasyDict({'obs': len_env_0_2, 'done': True})) + + for i in range(len_env_1_2): + timestep = EasyDict({'obs': i, 'done': False}) + transition_list.append(env_id=1, transition=timestep) + + transition_list_2 = copy.deepcopy(transition_list) + + env_0_result = transition_list.get_env_trajectories(env_id=0) + env_1_result = transition_list.get_env_trajectories(env_id=1) + + # print(env_0_result) + # print(env_1_result) + + assert len(env_0_result) == 3 + assert len(env_1_result) == 4 + + for trajectory in env_0_result: + assert len(trajectory) == unroll_len + for trajectory in env_1_result: + assert len(trajectory) == unroll_len + #env_0 + i = 0 + trajectory = env_0_result[0] + for transition in trajectory: + assert transition.obs == i + i += 1 + + trajectory = env_0_result[1] + i = len_env_0 - unroll_len + 1 + for transition in trajectory: + assert transition.obs == i + i += 1 + + trajectory = env_0_result[2] + test_number = 0 + for i, transition in enumerate(trajectory): + if i < unroll_len - len_env_0_2 - 1: + assert transition.obs == 0 + else: + assert transition.obs == test_number + test_number += 1 + + #env_1 + i = 0 + for trajectory in env_1_result[:2]: + assert len(trajectory) == unroll_len + for transition in trajectory: + assert transition.obs == i + i += 1 + + trajectory = env_1_result[2] + assert len(trajectory) == unroll_len + + i = len_env_1 - unroll_len + 1 + for transition in trajectory: + assert transition.obs == i + i += 1 + + trajectory = env_1_result[3] + assert len(trajectory) == unroll_len + i = 0 + for transition in trajectory: + assert transition.obs == i + i += 1 + + transition_list_2.clear_newest_episode(env_id=0, before_append=True) + transition_list_2.clear_newest_episode(env_id=1, before_append=True) + assert len(transition_list_2._transitions[0]) == 2 + 
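The expected counts in this test follow directly from the BattleTransitionList cutting rules; a short derivation is spelled out here for readability (comments only, no new assertions):

```python
# With unroll_len = 32:
# env 0: finished episodes of length 49 and 13 -> windows [0:32] and [17:49] from the first
#        episode, plus one front-padded window from the 13-step episode  => 3 trajectories.
# env 1: finished episode of length 73 -> [0:32], [32:64], [41:73]; the unfinished 72-step
#        episode additionally contributes its first [0:32] window        => 4 trajectories.
# transition_list_2 was deep-copied before any cutting, so clear_newest_episode(before_append=True)
# is a no-op for env 0 (its newest episode already finished normally, 2 episodes remain) and
# drops env 1's unfinished episode (1 episode remains), matching the surrounding assertions.
```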
assert len(transition_list_2._transitions[1]) == 1 + + +@pytest.mark.unittest +def test_battle_transition_list(): + with task.start(): + test_battle_transition() diff --git a/ding/framework/middleware/tests/test_handle_step_exception.py b/ding/framework/middleware/tests/test_handle_step_exception.py new file mode 100644 index 0000000000..240f280be4 --- /dev/null +++ b/ding/framework/middleware/tests/test_handle_step_exception.py @@ -0,0 +1,50 @@ +from ding.framework.context import BattleContext +from ding.framework.middleware.functional.collector import BattleTransitionList +from ding.framework.middleware.functional import battle_rolloutor +import pytest +from unittest.mock import Mock +from ding.envs import BaseEnvTimestep +from easydict import EasyDict + + +class MockEnvManager: + + def __init__(self) -> None: + self.ready_obs = [[[]]] + + def step(self, actions): + timesteps = {} + for env_id in actions.keys(): + timesteps[env_id] = BaseEnvTimestep(obs=[1], reward=[1, 1], done=False, info={'abnormal': True}) + return timesteps + + +class MockPolicy: + + def __init__(self) -> None: + pass + + def reset(self, data): + pass + + +@pytest.mark.unittest +def test_handle_step_exception(): + ctx = BattleContext() + ctx.total_envstep_count = 10 + ctx.env_step = 20 + transitions_list = [BattleTransitionList(env_num=2, unroll_len=5)] + ctx.current_policies = [MockPolicy()] + for _ in range(5): + transitions_list[0].append(env_id=0, transition=BaseEnvTimestep(obs=[1], reward=[1, 1], done=False, info={})) + transitions_list[0].append(env_id=1, transition=BaseEnvTimestep(obs=[1], reward=[1, 1], done=False, info={})) + + ctx.actions = {0: {}} + ctx.obs = {0: {0: {}}} + rolloutor = battle_rolloutor( + cfg=EasyDict(), env=MockEnvManager(), transitions_list=transitions_list, model_info_dict=None + ) + rolloutor(ctx) + + assert len(transitions_list[0]._transitions[0]) == 0 + assert len(transitions_list[0]._transitions[1]) == 1 diff --git a/ding/framework/middleware/tests/test_league_actor.py b/ding/framework/middleware/tests/test_league_actor.py new file mode 100644 index 0000000000..3cd005530c --- /dev/null +++ b/ding/framework/middleware/tests/test_league_actor.py @@ -0,0 +1,108 @@ +from time import sleep +import pytest +from copy import deepcopy +from ding.envs import BaseEnvManager +from ding.framework.context import BattleContext +from ding.framework.middleware.league_learner_communicator import LearnerModel +from ding.framework.middleware.tests.mock_for_test import league_cfg +from ding.framework.middleware import StepLeagueActor +from ding.framework.middleware.functional import ActorData +from ding.league.player import PlayerMeta +from ding.framework.storage import FileStorage + +from ding.framework.task import task, Parallel +from ding.league.v2.base_league import Job +from ding.model import VAC +from ding.policy.ppo import PPOPolicy +from dizoo.league_demo.game_env import GameEnv + +from ding.framework import EventEnum + + +def prepare_test(): + global league_cfg + cfg = deepcopy(league_cfg) + + def env_fn(): + env = BaseEnvManager( + env_fn=[lambda: GameEnv(cfg.env.env_type) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager + ) + env.seed(cfg.seed) + return env + + def policy_fn(): + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + return policy + + return cfg, env_fn, policy_fn + + +def _main(): + cfg, env_fn, policy_fn = prepare_test() + policy = policy_fn() + job = Job( + launch_player='main_player_default_0', + players=[ + 
PlayerMeta(player_id='main_player_default_0', checkpoint=FileStorage(path=None), total_agent_step=0), + PlayerMeta(player_id='main_player_default_1', checkpoint=FileStorage(path=None), total_agent_step=0) + ] + ) + ACTOR_ID = 0 + + with task.start(async_mode=True, ctx=BattleContext()): + league_actor = StepLeagueActor(cfg, env_fn, policy_fn) + + def test_actor(): + testcases = { + "on_actor_greeting": False, + "on_actor_job": False, + "on_actor_data": False, + } + + def on_actor_greeting(actor_id): + assert actor_id == ACTOR_ID + testcases["on_actor_greeting"] = True + task.emit(EventEnum.COORDINATOR_DISPATCH_ACTOR_JOB.format(actor_id=ACTOR_ID), job) + + def on_actor_job(job_: Job): + assert job_.launch_player == job.launch_player + testcases["on_actor_job"] = True + + def on_actor_data(actor_data): + assert isinstance(actor_data, ActorData) + testcases["on_actor_data"] = True + + task.on(EventEnum.ACTOR_GREETING, on_actor_greeting) + task.on(EventEnum.ACTOR_FINISH_JOB, on_actor_job) + task.on(EventEnum.ACTOR_SEND_DATA.format(player=job.launch_player), on_actor_data) + + def _test_actor(ctx): + sleep(0.3) + + task.emit( + EventEnum.LEARNER_SEND_MODEL, + LearnerModel( + player_id='main_player_default_0', state_dict=policy.learn_mode.state_dict(), train_iter=0 + ) + ) + sleep(10) + try: + print(testcases) + assert all(testcases.values()) + finally: + task.finish = True + + return _test_actor + + if task.router.node_id == ACTOR_ID: + task.use(league_actor) + elif task.router.node_id == 1: + task.use(test_actor()) + + task.run(max_step=5) + + +@pytest.mark.unittest +def test_league_actor(): + Parallel.runner(n_parallel_workers=2, protocol="tcp", topology="mesh")(_main) diff --git a/ding/framework/middleware/tests/test_league_coordinator.py b/ding/framework/middleware/tests/test_league_coordinator.py new file mode 100644 index 0000000000..fee2713850 --- /dev/null +++ b/ding/framework/middleware/tests/test_league_coordinator.py @@ -0,0 +1,74 @@ +import pytest +import time +from unittest.mock import patch +from ding.framework import task, Parallel +from ding.framework.middleware import LeagueCoordinator +from ding.league.v2 import Job +from ding.framework import EventEnum +from ding.league.player import PlayerMeta + + +class MockLeague: + + def __init__(self): + self.active_players_ids = ["player_0", "player_1", "player_2"] + self.update_payoff_cnt = 0 + self.update_active_player_cnt = 0 + self.create_historical_player_cnt = 0 + self.get_job_info_cnt = 0 + + def update_payoff(self, job): + self.update_payoff_cnt += 1 + + def update_active_player(self, meta): + self.update_active_player_cnt += 1 + + def create_historical_player(self, meta): + self.create_historical_player_cnt += 1 + + def get_job_info(self, player_id): + self.get_job_info_cnt += 1 + return Job(launch_player=player_id, players=[]) + + +def _main(): + with task.start(): + if task.router.node_id == 0: + with patch("ding.league.BaseLeague", MockLeague): + league = MockLeague() + coordinator = LeagueCoordinator(None, league) + time.sleep(3) + assert league.update_payoff_cnt == 3 + assert league.update_active_player_cnt == 3 + assert league.create_historical_player_cnt == 3 + assert league.get_job_info_cnt == 3 + elif task.router.node_id == 1: + # test ACTOR_GREETING + res = [] + task.on( + EventEnum.COORDINATOR_DISPATCH_ACTOR_JOB.format(actor_id=task.router.node_id), + lambda job: res.append(job) + ) + for _ in range(3): + task.emit(EventEnum.ACTOR_GREETING, task.router.node_id) + time.sleep(5) + assert task.router.node_id == 
res[-1].actor_id + elif task.router.node_id == 2: + # test LEARNER_SEND_META + for i in range(3): + player_meta = PlayerMeta(player_id="test_player_{}".format(i), checkpoint=None) + task.emit(EventEnum.LEARNER_SEND_META, player_meta) + time.sleep(3) + elif task.router.node_id == 3: + # test ACTOR_FINISH_JOB + job = Job(-1, task.router.node_id, False) + for _ in range(3): + task.emit(EventEnum.ACTOR_FINISH_JOB, job) + time.sleep(3) + else: + raise Exception("Invalid node id {}".format(task.router.is_active)) + + +@pytest.mark.unittest +def test_coordinator(): + Parallel.runner(n_parallel_workers=4, protocol="tcp", topology="star")(_main) diff --git a/ding/framework/middleware/tests/test_league_learner_communicator.py b/ding/framework/middleware/tests/test_league_learner_communicator.py new file mode 100644 index 0000000000..7207eb3c14 --- /dev/null +++ b/ding/framework/middleware/tests/test_league_learner_communicator.py @@ -0,0 +1,126 @@ +from copy import deepcopy +from dataclasses import dataclass +from time import sleep +import time +import pytest +import logging +from typing import Any +from unittest.mock import patch +from typing import Callable, Optional + +from ding.framework.context import BattleContext +from ding.framework import EventEnum +from ding.framework.task import task, Parallel +from ding.framework.middleware import LeagueLearnerCommunicator, LearnerModel +from ding.framework.middleware.functional.actor_data import ActorData, ActorDataMeta, ActorEnvTrajectories +from ding.framework.middleware.tests.mock_for_test import league_cfg + +from ding.model import VAC +from ding.policy.ppo import PPOPolicy + +PLAYER_ID = "test_player" + + +def prepare_test(): + global league_cfg + cfg = deepcopy(league_cfg) + + def policy_fn(): + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + return policy + + return cfg, policy_fn + + +@dataclass +class TestActorData: + env_step: int + train_data: Any + + +class MockFileStorage: + + def __init__(self, path: str) -> None: + self.path = path + + def save(self, data: Any) -> bool: + assert isinstance(data, dict) + + +class MockPlayer: + + def __init__(self) -> None: + self.player_id = PLAYER_ID + self.total_agent_step = 0 + + def is_trained_enough(self) -> bool: + return True + + +def coordinator_mocker(): + + test_cases = {"on_learner_meta": False} + + def on_learner_meta(player_meta): + assert player_meta.player_id == PLAYER_ID + test_cases["on_learner_meta"] = True + + task.on(EventEnum.LEARNER_SEND_META.format(player=PLAYER_ID), on_learner_meta) + + def _coordinator_mocker(ctx): + sleep(0.8) + assert all(test_cases.values()) + + return _coordinator_mocker + + +def actor_mocker(): + + test_cases = {"on_learner_model": False} + + def on_learner_model(learner_model): + assert isinstance(learner_model, LearnerModel) + assert learner_model.player_id == PLAYER_ID + test_cases["on_learner_model"] = True + + task.on(EventEnum.LEARNER_SEND_MODEL.format(player=PLAYER_ID), on_learner_model) + + def _actor_mocker(ctx): + sleep(0.2) + player = MockPlayer() + for _ in range(10): + meta = ActorDataMeta(player_total_env_step=0, actor_id=0, send_wall_time=time.time()) + data = [] + actor_data = ActorData(meta=meta, train_data=[ActorEnvTrajectories(env_id=0, trajectories=[data])]) + task.emit(EventEnum.ACTOR_SEND_DATA.format(player=player.player_id), actor_data) + + sleep(0.8) + assert all(test_cases.values()) + + return _actor_mocker + + +def _main(): + logging.getLogger().setLevel(logging.INFO) + cfg, policy_fn = prepare_test() + 
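These cross-process tests ship policy state_dicts through the message queue, which is why this diff also adds a CPU-mapping unpickler to ding/framework/parallel.py further down: a CPU-only worker can then deserialize tensors that were pickled on a GPU node. A minimal round-trip sketch using that helper (illustrative only):

```python
import pickle
import torch
from ding.framework.parallel import cpu_loads  # helper added later in this diff

# Whatever device the sender's tensors lived on, cpu_loads() deserializes them onto
# the CPU via torch.load(map_location='cpu'), so a CPU-only node can still consume them.
payload = {"weight": torch.randn(2, 2)}  # stands in for a (possibly CUDA) state_dict
msg = pickle.dumps(payload)              # what the message queue actually transports
restored = cpu_loads(msg)
assert restored["weight"].device.type == "cpu"
```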
+ with task.start(async_mode=False, ctx=BattleContext()): + if task.router.node_id == 0: + task.use(coordinator_mocker()) + elif task.router.node_id <= 1: + task.use(actor_mocker()) + else: + player = MockPlayer() + policy = policy_fn() + with patch("ding.framework.storage.FileStorage", MockFileStorage): + learner_communicator = LeagueLearnerCommunicator(cfg, policy.learn_mode, player) + sleep(0.5) + task.use(learner_communicator) + sleep(0.1) + task.run(max_step=5) + + +@pytest.mark.unittest +def test_league_learner_communicator(): + Parallel.runner(n_parallel_workers=3, protocol="tcp", topology="mesh")(_main) diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py index 70134f6584..20a3c750ca 100644 --- a/ding/framework/parallel.py +++ b/ding/framework/parallel.py @@ -5,6 +5,7 @@ import traceback from mpire.pool import WorkerPool import pickle +import io from ditk import logging import tempfile import socket @@ -15,11 +16,40 @@ from ding.utils.design_helper import SingletonMetaclass from ding.framework.message_queue import * from ding.utils.registry_factory import MQ_REGISTRY +import torch # Avoid ipc address conflict, random should always use random seed random = random.Random() +class CpuUnpickler(pickle.Unpickler): + + def find_class(self, module, name): + if module == 'torch.storage' and name == '_load_from_bytes': + return lambda b: torch.load(io.BytesIO(b), map_location='cpu') + else: + return super().find_class(module, name) + + +def cpu_loads(x): + bs = io.BytesIO(x) + unpickler = CpuUnpickler(bs) + return unpickler.load() + + +def my_pickle_loads(msg): + """ + Overview: + This method allows you to recieve gpu tensors from gpu bug freely, if you are in an only-cpu node. + refrence: https://github.com/pytorch/pytorch/issues/16797 + """ + if not torch.cuda.is_available(): + payload = cpu_loads(msg) + else: + payload = pickle.loads(msg) + return payload + + class Parallel(metaclass=SingletonMetaclass): def __init__(self) -> None: @@ -339,7 +369,7 @@ def _handle_message(self, topic: str, msg: bytes) -> None: logging.debug("Event {} was not listened in parallel {}".format(event, self.node_id)) return try: - payload = pickle.loads(msg) + payload = my_pickle_loads(msg) except Exception as e: logging.error("Error when unpacking message on node {}, msg: {}".format(self.node_id, e)) return diff --git a/ding/framework/storage/__init__.py b/ding/framework/storage/__init__.py new file mode 100644 index 0000000000..6d8c388953 --- /dev/null +++ b/ding/framework/storage/__init__.py @@ -0,0 +1,2 @@ +from .storage import Storage +from .file import FileStorage diff --git a/ding/framework/storage/file.py b/ding/framework/storage/file.py new file mode 100644 index 0000000000..18bcc2c39d --- /dev/null +++ b/ding/framework/storage/file.py @@ -0,0 +1,12 @@ +from typing import Any +from ding.framework.storage import Storage +from ding.utils import read_file, save_file + + +class FileStorage(Storage): + + def save(self, data: Any) -> None: + save_file(self.path, data) + + def load(self) -> Any: + return read_file(self.path) diff --git a/ding/framework/storage/storage.py b/ding/framework/storage/storage.py new file mode 100644 index 0000000000..6ba9e81c9d --- /dev/null +++ b/ding/framework/storage/storage.py @@ -0,0 +1,16 @@ +from abc import abstractmethod +from typing import Any + + +class Storage: + + def __init__(self, path: str) -> None: + self.path = path + + @abstractmethod + def save(self, data: Any) -> None: + raise NotImplementedError + + @abstractmethod + def load(self) -> Any: + 
raise NotImplementedError diff --git a/ding/framework/storage/tests/test_storage.py b/ding/framework/storage/tests/test_storage.py new file mode 100644 index 0000000000..53eebe3fa2 --- /dev/null +++ b/ding/framework/storage/tests/test_storage.py @@ -0,0 +1,18 @@ +import tempfile +import pytest +import os +from os import path +from ding.framework.storage import FileStorage + + +@pytest.mark.unittest +def test_file_storage(): + path_ = path.join(tempfile.gettempdir(), "test_storage.txt") + try: + storage = FileStorage(path=path_) + storage.save("test") + content = storage.load() + assert content == "test" + finally: + if path.exists(path_): + os.remove(path_) diff --git a/ding/league/algorithm.py b/ding/league/algorithm.py index 09cb23b352..b7de39970a 100644 --- a/ding/league/algorithm.py +++ b/ding/league/algorithm.py @@ -22,7 +22,7 @@ def pfsp(win_rates: np.ndarray, weighting: str) -> np.ndarray: raise KeyError("invalid weighting arg: {} in pfsp".format(weighting)) assert isinstance(win_rates, np.ndarray) - assert win_rates.shape[0] >= 1, win_rates.shape + assert win_rates.shape[0] >= 1, "win rate is {}".format(win_rates) # all zero win rates case, return uniform selection prob if win_rates.sum() < 1e-8: return np.full_like(win_rates, 1.0 / len(win_rates)) diff --git a/ding/league/player.py b/ding/league/player.py index e253c0bdad..7e2e031be2 100644 --- a/ding/league/player.py +++ b/ding/league/player.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass +import traceback from typing import Callable, Optional, List from collections import namedtuple import numpy as np @@ -5,6 +7,20 @@ from ding.utils import import_module, PLAYER_REGISTRY from .algorithm import pfsp +from ding.framework.storage import FileStorage +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ding.league.shared_payoff import BattleSharedPayoff + from ding.league.metric import LeagueMetricEnv + from ding.league.metric import PlayerRating + from ding.framework.storage import Storage + + +@dataclass +class PlayerMeta: + player_id: str + checkpoint: "Storage" + total_agent_step: int = 0 class Player: @@ -22,11 +38,11 @@ def __init__( self, cfg: EasyDict, category: str, - init_payoff: 'BattleSharedPayoff', # noqa + init_payoff: 'BattleSharedPayoff', checkpoint_path: str, player_id: str, total_agent_step: int, - rating: 'PlayerRating', # noqa + rating: 'PlayerRating', ) -> None: """ Overview: @@ -46,6 +62,7 @@ def __init__( self._category = category self._payoff = init_payoff self._checkpoint_path = checkpoint_path + self.checkpoint = FileStorage(path=checkpoint_path) assert isinstance(player_id, str) self._player_id = player_id assert isinstance(total_agent_step, int), (total_agent_step, type(total_agent_step)) @@ -57,7 +74,7 @@ def category(self) -> str: return self._category @property - def payoff(self) -> 'BattleSharedPayoff': # noqa + def payoff(self) -> 'BattleSharedPayoff': return self._payoff @property @@ -77,13 +94,17 @@ def total_agent_step(self, step: int) -> None: self._total_agent_step = step @property - def rating(self) -> 'PlayerRating': # noqa + def rating(self) -> 'PlayerRating': return self._rating @rating.setter - def rating(self, _rating: 'PlayerRating') -> None: # noqa + def rating(self, _rating: 'PlayerRating') -> None: self._rating = _rating + @property + def meta(self) -> PlayerMeta: + return PlayerMeta(player_id=self.player_id, checkpoint=self.checkpoint, total_agent_step=self.total_agent_step) + @PLAYER_REGISTRY.register('historical_player') class HistoricalPlayer(Player): @@ -179,7 +200,9 
@@ def is_trained_enough(self, select_fn: Optional[Callable] = None) -> bool: else: return False - def snapshot(self, metric_env: 'LeagueMetricEnv') -> HistoricalPlayer: # noqa + def snapshot( + self, metric_env: 'LeagueMetricEnv', checkpoint: Optional["Storage"] = None + ) -> HistoricalPlayer: # noqa """ Overview: Generate a snapshot historical player from the current player, called in league's ``_snapshot``. @@ -192,8 +215,11 @@ def snapshot(self, metric_env: 'LeagueMetricEnv') -> HistoricalPlayer: # noqa This method only generates a historical player object, but without saving the checkpoint, which should be done by league. """ - path = self.checkpoint_path.split('.pth')[0] + '_{}'.format(self._total_agent_step) + '.pth' - return HistoricalPlayer( + if checkpoint: + path = checkpoint.path + else: + path = self.checkpoint_path.split('.pth')[0] + '_{}'.format(self._total_agent_step) + '.pth' + hp = HistoricalPlayer( self._cfg, self.category, self.payoff, @@ -203,6 +229,9 @@ def snapshot(self, metric_env: 'LeagueMetricEnv') -> HistoricalPlayer: # noqa metric_env.create_rating(mu=self.rating.mu), parent_id=self.player_id ) + if checkpoint: + hp.checkpoint = checkpoint + return hp def mutate(self, info: dict) -> Optional[str]: """ diff --git a/ding/league/starcraft_player.py b/ding/league/starcraft_player.py index 81d53e73bd..900fed6149 100644 --- a/ding/league/starcraft_player.py +++ b/ding/league/starcraft_player.py @@ -54,6 +54,17 @@ def _sp_branch(self): p = pfsp(win_rates, weighting='variance') return self._get_opponent(historical, p) + def _sl_branch(self): + """ + Overview: + Select one opponent, whose ckpt is sl_model.pth + """ + historical = self._get_players( + lambda p: isinstance(p, HistoricalPlayer) and p.player_id == 'main_player_default_0_pretrain_historical' + ) + main_opponent = self._get_opponent(historical) + return main_opponent + def _verification_branch(self): """ Overview: diff --git a/ding/league/v2/__init__.py b/ding/league/v2/__init__.py new file mode 100644 index 0000000000..0eb76f750a --- /dev/null +++ b/ding/league/v2/__init__.py @@ -0,0 +1 @@ +from .base_league import BaseLeague, Job diff --git a/ding/league/v2/base_league.py b/ding/league/v2/base_league.py new file mode 100644 index 0000000000..f5ca20eaa6 --- /dev/null +++ b/ding/league/v2/base_league.py @@ -0,0 +1,215 @@ +from dataclasses import dataclass, field +from typing import List +import copy +from easydict import EasyDict + +from ding.league.player import ActivePlayer, HistoricalPlayer, create_player +from ding.league.shared_payoff import create_payoff +from ding.utils import deep_merge_dicts +from ding.league.metric import LeagueMetricEnv +from ding.framework.storage import Storage +from typing import TYPE_CHECKING +from ditk import logging +if TYPE_CHECKING: + from ding.league import Player, PlayerMeta + + +@dataclass +class Job: + launch_player: str + players: List["PlayerMeta"] + result: list = field(default_factory=list) + job_no: int = 0 # Serial number of job, not required + train_iter: int = None + is_eval: bool = False + + +class BaseLeague: + """ + Overview: + League, proposed by Google Deepmind AlphaStar. Can manage multiple players in one league. + Interface: + get_job_info, create_historical_player, update_active_player, update_payoff + + .. note:: + In ``__init__`` method, league would also initialized players as well(in ``_init_players`` method). 
+ """ + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + config = dict( + league_type='baseV2', + import_names=["ding.league.v2.base_league"], + # ---player---- + # "player_category" is just a name. Depends on the env. + # For example, in StarCraft, this can be ['zerg', 'terran', 'protoss']. + player_category=['default'], + # Support different types of active players for solo and battle league. + # For solo league, supports ['solo_active_player']. + # For battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter']. + # active_players=dict(), + # "use_pretrain" means whether to use pretrain model to initialize active player. + use_pretrain=False, + # "use_pretrain_init_historical" means whether to use pretrain model to initialize historical player. + # "pretrain_checkpoint_path" is the pretrain checkpoint path used in "use_pretrain" and + # "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted as well. + # Otherwise, "pretrain_checkpoint_path" should list paths of all player categories. + use_pretrain_init_historical=False, + pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ), + # ---payoff--- + payoff=dict( + # Supports ['battle'] + type='battle', + decay=0.99, + min_win_rate_games=8, + ), + metric=dict( + mu=0, + sigma=25 / 3, + beta=25 / 3 / 2, + tau=0.0, + draw_probability=0.02, + ), + ) + + def __init__(self, cfg: EasyDict) -> None: + """ + Overview: + Initialization method. + Arguments: + - cfg (:obj:`EasyDict`): League config. + """ + self.cfg = deep_merge_dicts(self.default_config(), cfg) + self.active_players: List["ActivePlayer"] = [] + self.historical_players: List["HistoricalPlayer"] = [] + self.payoff = create_payoff(self.cfg.payoff) + metric_cfg = self.cfg.metric + self.metric_env = LeagueMetricEnv(metric_cfg.mu, metric_cfg.sigma, metric_cfg.tau, metric_cfg.draw_probability) + self._init_players() + + def _init_players(self) -> None: + """ + Overview: + Initialize players (active & historical) in the league. + """ + # Add different types of active players for each player category, according to ``cfg.active_players``. + for cate in self.cfg.player_category: # Player's category (Depends on the env) + for k, n in self.cfg.active_players.items(): # Active player's type + for i in range(n): # This type's active player number + name = '{}_{}_{}'.format(k, cate, i) + player = create_player( + self.cfg, k, self.cfg[k], cate, self.payoff, None, name, 0, self.metric_env.create_rating() + ) + self.active_players.append(player) + self.payoff.add_player(player) + + # Add pretrain player as the initial HistoricalPlayer for each player category. 
+ if self.cfg.use_pretrain_init_historical: + for cate in self.cfg.player_category: + main_player_name = [k for k in self.cfg.keys() if 'main_player' in k] + assert len(main_player_name) == 1, main_player_name + main_player_name = main_player_name[0] + name = '{}_{}_0_pretrain_historical'.format(main_player_name, cate) + parent_name = '{}_{}_0'.format(main_player_name, cate) + hp = HistoricalPlayer( + self.cfg.get(main_player_name), + cate, + self.payoff, + self.cfg.pretrain_checkpoint_path[cate], + name, + 0, + self.metric_env.create_rating(), + parent_id=parent_name + ) + self.historical_players.append(hp) + self.payoff.add_player(hp) + + # Save active players' ``player_id`` + self.active_players_ids = [p.player_id for p in self.active_players] + # Validate active players are unique by ``player_id``. + assert len(self.active_players_ids) == len(set(self.active_players_ids)) + + def get_job_info(self, player_id: str = None) -> Job: + """ + Overview: + Get info dict of the job which is to be launched to an active player. + Arguments: + - player_id (:obj:`str`): The active player's id. + Returns: + - job_info (:obj:`dict`): Job info. + ReturnsKeys: + - necessary: ``launch_player`` (the active player) + """ + if player_id is None: + player_id = self.active_players_ids[0] + idx = self.active_players_ids.index(player_id) + player = self.active_players[idx] + player_job_info = player.get_job() + opponent_player = player_job_info["opponent"] + job = Job(launch_player=player_id, players=[player.meta, opponent_player.meta]) + return job + + def create_historical_player(self, player_meta: "PlayerMeta", force: bool = False) -> None: + """ + Overview: + Judge whether a player is trained enough for snapshot. If yes, call player's ``snapshot``, create a + historical player(prepare the checkpoint and add it to the shared payoff), then mutate it, and return True. + Otherwise, return False. + Arguments: + - player_id (:obj:`ActivePlayer`): The active player's id. + """ + idx = self.active_players_ids.index(player_meta.player_id) + player: "ActivePlayer" = self.active_players[idx] + if force or (player_meta.checkpoint and player.is_trained_enough()): + # Snapshot + hp = player.snapshot(self.metric_env, player_meta.checkpoint) + self.historical_players.append(hp) + self.payoff.add_player(hp) + + def update_active_player(self, player_meta: "PlayerMeta") -> None: + """ + Overview: + Update an active player's info. + Arguments: + - player_id (:obj:`str`): Player id. + - train_iter (:obj:`int`): Train iteration. + """ + idx = self.active_players_ids.index(player_meta.player_id) + player = self.active_players[idx] + if isinstance(player, ActivePlayer): + player.total_agent_step = player_meta.total_agent_step + + def update_payoff(self, job: Job) -> None: + """ + Overview: + Finish current job. Update shared payoff to record the game results. + Arguments: + - job_info (:obj:`dict`): A dict containing job result information. 
+ """ + job_info = { + "launch_player": job.launch_player, + "player_id": list(map(lambda p: p.player_id, job.players)), + "result": job.result + } + self.payoff.update(job_info) + + logging.info("show the current payoff {}".format(self.payoff._data)) + # Update player rating + home_id, away_id = job_info['player_id'] + home_player, away_player = self.get_player_by_id(home_id), self.get_player_by_id(away_id) + job_info_result = job_info['result'] + if isinstance(job_info_result[0], list): + job_info_result = sum(job_info_result, []) + home_player.rating, away_player.rating = self.metric_env.rate_1vs1( + home_player.rating, away_player.rating, result=job_info_result + ) + + def get_player_by_id(self, player_id: str) -> 'Player': + if 'historical' in player_id: + return [p for p in self.historical_players if p.player_id == player_id][0] + else: + return [p for p in self.active_players if p.player_id == player_id][0] diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index a30b080015..52491ed5cb 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -69,7 +69,7 @@ def __init__( assert set(self._enable_field).issubset(self.total_field), self._enable_field if len(set(self._enable_field).intersection(set(['learn', 'collect', 'eval']))) > 0: - model = self._create_model(cfg, model) + model = self._create_model(cfg, model, enable_field) self._cuda = cfg.cuda and torch.cuda.is_available() # now only support multi-gpu for only enable learn mode if len(set(self._enable_field).intersection(set(['learn']))) > 0: @@ -117,7 +117,12 @@ def hook(*ignore): grad_acc = p_tmp.grad_fn.next_functions[0][0] grad_acc.register_hook(make_hook(name, p)) - def _create_model(self, cfg: dict, model: Optional[torch.nn.Module] = None) -> torch.nn.Module: + def _create_model( + self, + cfg: dict, + model: Optional[torch.nn.Module] = None, + enable_field: Optional[List[str]] = None + ) -> torch.nn.Module: if model is None: model_cfg = cfg.model if 'type' not in model_cfg: diff --git a/ding/rl_utils/__init__.py b/ding/rl_utils/__init__.py index d622d19600..673ce2da19 100644 --- a/ding/rl_utils/__init__.py +++ b/ding/rl_utils/__init__.py @@ -12,11 +12,11 @@ fqf_nstep_td_data, fqf_nstep_td_error, fqf_calculate_fraction_loss, evaluate_quantile_at_action, \ q_nstep_sql_td_error, dqfd_nstep_td_error, dqfd_nstep_td_data, q_v_1step_td_error, q_v_1step_td_data,\ dqfd_nstep_td_error_with_rescale, discount_cumsum -from .vtrace import vtrace_loss, compute_importance_weights -from .upgo import upgo_loss +from .upgo import upgo_data, upgo_error from .adder import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample from .value_rescale import value_transform, value_inv_transform -from .vtrace import vtrace_data, vtrace_error +from .vtrace import vtrace_data, vtrace_error, vtrace_loss, vtrace_data_with_rho, vtrace_error_with_rho, \ + compute_importance_weights from .beta_function import beta_function_map from .retrace import compute_q_retraces from .acer import acer_policy_error, acer_value_error, acer_trust_region_update diff --git a/ding/rl_utils/tests/test_upgo.py b/ding/rl_utils/tests/test_upgo.py index 5bd96d9c7e..075cfa2d54 100644 --- a/ding/rl_utils/tests/test_upgo.py +++ b/ding/rl_utils/tests/test_upgo.py @@ -1,6 +1,6 @@ import pytest import torch -from ding.rl_utils.upgo import upgo_loss, upgo_returns, tb_cross_entropy +from ding.rl_utils.upgo import upgo_data, upgo_error, upgo_returns, tb_cross_entropy @pytest.mark.unittest @@ -31,7 +31,10 @@ def test_upgo(): # 
upgo loss rhos = torch.randn(T, B) - loss = upgo_loss(logit, rhos, action, rewards, bootstrap_values) + dist = torch.distributions.Categorical(logits=logit) + log_prob = dist.log_prob(action) + data = upgo_data(log_prob, rhos, bootstrap_values, rewards, torch.ones_like(rewards)) + loss = upgo_error(data) assert logit.requires_grad assert bootstrap_values.requires_grad for t in [logit, bootstrap_values]: diff --git a/ding/rl_utils/upgo.py b/ding/rl_utils/upgo.py index 366d67ada4..ca333f475e 100644 --- a/ding/rl_utils/upgo.py +++ b/ding/rl_utils/upgo.py @@ -1,6 +1,6 @@ import torch import torch.nn.functional as F -from ding.hpc_rl import hpc_wrapper +from collections import namedtuple from .td import generalized_lambda_returns @@ -48,40 +48,29 @@ def upgo_returns(rewards: torch.Tensor, bootstrap_values: torch.Tensor) -> torch return generalized_lambda_returns(bootstrap_values, rewards, 1.0, lambdas) -@hpc_wrapper( - shape_fn=lambda args: args[0].shape, - namedtuple_data=True, - include_args=5, - include_kwargs=['target_output', 'rhos', 'action', 'rewards', 'bootstrap_values'] -) -def upgo_loss( - target_output: torch.Tensor, - rhos: torch.Tensor, - action: torch.Tensor, - rewards: torch.Tensor, - bootstrap_values: torch.Tensor, - mask=None -) -> torch.Tensor: - r""" +upgo_data = namedtuple('upgo_data', ['target_log_prob', 'rhos', 'bootstrap_values', 'rewards', 'weights']) + + +def upgo_error(data: namedtuple, ) -> torch.Tensor: + """ Overview: - Computing UPGO loss given constant gamma and lambda. There is no special handling for terminal state value, + Computing UPGO loss given constant gamma and lambda. There is no special handling for terminal state value, \ if the last state in trajectory is the terminal, just pass a 0 as bootstrap_terminal_value. Arguments: - - target_output (:obj:`torch.Tensor`): the output computed by the target policy network, \ + - target_log_prob (:obj:`torch.Tensor`): The output computed by the target policy network, \ of size [T_traj, batchsize, n_output] - - rhos (:obj:`torch.Tensor`): the importance sampling ratio, of size [T_traj, batchsize] - - action (:obj:`torch.Tensor`): the action taken, of size [T_traj, batchsize] - - rewards (:obj:`torch.Tensor`): the returns from time step 0 to T-1, of size [T_traj, batchsize] - - bootstrap_values (:obj:`torch.Tensor`): estimation of the state value at step 0 to T, \ + - rhos (:obj:`torch.Tensor`): The importance sampling ratio, of size [T_traj, batchsize] + - bootstrap_values (:obj:`torch.Tensor`): The estimation of the state value at step 0 to T, \ of size [T_traj+1, batchsize] + - rewards (:obj:`torch.Tensor`): The returns from time step 0 to T-1, of size [T_traj, batchsize] + - weights (:obj:`torch.Tensor`): Data weights per sample, of size [T_traj, batchsize]. Returns: - - loss (:obj:`torch.Tensor`): Computed importance sampled UPGO loss, averaged over the samples, of size [] + - loss (:obj:`torch.Tensor`): Computed importance sampled UPGO loss, averaged over the samples, 0-dim tensor. 
""" + target_log_prob, rhos, bootstrap_values, rewards, weights = data # discard the value at T as it should be considered in the next slice with torch.no_grad(): returns = upgo_returns(rewards, bootstrap_values) advantages = rhos * (returns - bootstrap_values[:-1]) - metric = tb_cross_entropy(target_output, action, mask) - assert (metric.shape == action.shape[:2]) - losses = advantages * metric - return -losses.mean() + loss = -advantages * target_log_prob * weights + return loss.mean() diff --git a/ding/rl_utils/vtrace.py b/ding/rl_utils/vtrace.py index b8c262e1a0..ef2761dbe0 100644 --- a/ding/rl_utils/vtrace.py +++ b/ding/rl_utils/vtrace.py @@ -69,13 +69,13 @@ def shape_fn_vtrace(args, kwargs): include_kwargs=['data', 'gamma', 'lambda_', 'rho_clip_ratio', 'c_clip_ratio', 'rho_pg_clip_ratio'] ) def vtrace_error( - data: namedtuple, - gamma: float = 0.99, - lambda_: float = 0.95, - rho_clip_ratio: float = 1.0, - c_clip_ratio: float = 1.0, - rho_pg_clip_ratio: float = 1.0 -): + data: namedtuple, + gamma: float = 0.99, + lambda_: float = 0.95, + rho_clip_ratio: float = 1.0, + c_clip_ratio: float = 1.0, + rho_pg_clip_ratio: float = 1.0 +) -> namedtuple: """ Overview: Implementation of vtrace(IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\ @@ -124,3 +124,17 @@ def vtrace_error( value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean() entropy_loss = (dist_target.entropy() * weight).mean() return vtrace_loss(pg_loss, value_loss, entropy_loss) + + +vtrace_data_with_rho = namedtuple('vtrace_data_with_rho', ['target_log_prob', 'rho', 'value', 'reward', 'weight']) + + +def vtrace_error_with_rho(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.95) -> torch.Tensor: + target_log_prob, rho, value, reward, weight = data + with torch.no_grad(): + return_ = vtrace_nstep_return(rho, rho, reward, value, gamma, lambda_) + # pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio) + return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0) + adv = vtrace_advantage(rho, reward, return_t_plus_1, value[:-1], gamma) + loss = -adv * target_log_prob * weight + return loss.mean() diff --git a/ding/torch_utils/__init__.py b/ding/torch_utils/__init__.py index 9bd2e63a3f..7a76015f1c 100644 --- a/ding/torch_utils/__init__.py +++ b/ding/torch_utils/__init__.py @@ -1,8 +1,8 @@ from .checkpoint_helper import build_checkpoint_helper, CountVar, auto_checkpoint from .data_helper import to_device, to_tensor, to_ndarray, to_list, to_dtype, same_shape, tensor_to_list, \ - build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, get_null_data + build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, get_null_data, detach_grad, flatten from .distribution import CategoricalPd, CategoricalPdPytorch -from .metric import levenshtein_distance, hamming_distance +from .metric import l2_distance, levenshtein_distance, hamming_distance from .network import * from .loss import * from .optimizer_helper import Adam, RMSprop, calculate_grad_norm, calculate_grad_norm_without_bias_two_norm diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 9e0b8e7861..e87e88c770 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -1,4 +1,4 @@ -from typing import Iterable, Any, Optional, List +from typing import Iterable, Any, Optional, List, Mapping from collections.abc import Sequence import numbers import time @@ -397,3 +397,32 @@ def get_null_data(template: Any, num: int) -> List[Any]: data['reward'].zero_() ret.append(data) return 
ret + + +def detach_grad(data): + if isinstance(data, Sequence): + for i in range(len(data)): + data[i] = detach_grad(data[i]) + elif isinstance(data, Mapping): + for k in data.keys(): + data[k] = detach_grad(data[k]) + elif isinstance(data, torch.Tensor): + data = data.detach() + else: + raise TypeError("not support data type: {}".format(type(data))) + return data + + +def flatten(data): + if isinstance(data, torch.Tensor): + return torch.flatten(data, start_dim=0, end_dim=1) # (1, (T+1) * B) + elif isinstance(data, dict): + new_data = {} + for k, val in data.items(): + new_data[k] = flatten(val) + return new_data + elif isinstance(data, Sequence): + new_data = [flatten(v) for v in data] + return new_data + else: + raise TypeError("not support data type: {}".format(type(data))) diff --git a/ding/torch_utils/metric.py b/ding/torch_utils/metric.py index 7fc2edf230..827bee4634 100644 --- a/ding/torch_utils/metric.py +++ b/ding/torch_utils/metric.py @@ -2,6 +2,16 @@ from typing import Optional, Callable +def l2_distance(a, b, min=0, max=0.8, threshold=5, spatial_x=160): + x0 = a % spatial_x + y0 = a // spatial_x + x1 = b % spatial_x + y1 = b // spatial_x + l2 = torch.sqrt((torch.square(x1 - x0) + torch.square(y1 - y0)).float()) + cost = (l2 / threshold).clamp_(min=min, max=max) + return cost + + def levenshtein_distance( pred: torch.LongTensor, target: torch.LongTensor, diff --git a/ding/torch_utils/network/__init__.py b/ding/torch_utils/network/__init__.py index 2c71ffc8a9..3d5a285f2b 100644 --- a/ding/torch_utils/network/__init__.py +++ b/ding/torch_utils/network/__init__.py @@ -1,12 +1,13 @@ -from .activation import build_activation, Swish -from .res_block import ResBlock, ResFCBlock +from .activation import build_activation, Swish, build_activation2 +from .res_block import ResBlock, ResFCBlock, GatedConvResBlock from .nn_module import fc_block, conv2d_block, one_hot, deconv2d_block, BilinearUpsample, NearestUpsample, \ - binary_encode, NoiseLinearLayer, noise_block, MLP, Flatten, normed_linear, normed_conv2d + binary_encode, NoiseLinearLayer, noise_block, MLP, Flatten, normed_linear, normed_conv2d, AttentionPool from .normalization import build_normalization from .rnn import get_lstm, sequence_mask from .soft_argmax import SoftArgmax from .transformer import Transformer -from .scatter_connection import ScatterConnection +from .scatter_connection import ScatterConnection, scatter_connection_v2 from .resnet import resnet18, ResNet from .gumbel_softmax import GumbelSoftmax from .gtrxl import GTrXL, GRUGatingUnit +from .script_lstm import script_lstm diff --git a/ding/torch_utils/network/activation.py b/ding/torch_utils/network/activation.py index 56afa03ce3..b6a6517e25 100644 --- a/ding/torch_utils/network/activation.py +++ b/ding/torch_utils/network/activation.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from ding.torch_utils.network.nn_module import fc_block, conv2d_block class GLU(nn.Module): @@ -91,3 +92,37 @@ def build_activation(activation: str, inplace: bool = None) -> nn.Module: return act_func[activation] else: raise KeyError("invalid key for activation: {}".format(activation)) + + +class GLU2(nn.Module): + r""" + Overview: + Gating Linear Unit different from GLU defined above, + Which uses fc_block instead of nn.Linear if input_type is 'fc', + and uses conv2d_block instead of nn.Conv2d if input_type is 'conv2d' + """ + + def __init__(self, input_dim, output_dim, context_dim, input_type='fc'): + super(GLU2, self).__init__() + assert 
(input_type in ['fc', 'conv2d']) + if input_type == 'fc': + self.layer1 = fc_block(context_dim, input_dim) + self.layer2 = fc_block(input_dim, output_dim) + elif input_type == 'conv2d': + self.layer1 = conv2d_block(context_dim, input_dim, 1, 1, 0) + self.layer2 = conv2d_block(input_dim, output_dim, 1, 1, 0) + + def forward(self, x, context): + gate = self.layer1(context) + gate = torch.sigmoid(gate) + x = gate * x + x = self.layer2(x) + return x + + +def build_activation2(activation): + act_func = {'relu': nn.ReLU(inplace=True), 'glu': GLU2, 'prelu': nn.PReLU(init=0.0)} + if activation in act_func.keys(): + return act_func[activation] + else: + raise KeyError("invalid key for activation: {}".format(activation)) diff --git a/ding/torch_utils/network/nn_module.py b/ding/torch_utils/network/nn_module.py index 784fbe3d90..719607d3ed 100644 --- a/ding/torch_utils/network/nn_module.py +++ b/ding/torch_utils/network/nn_module.py @@ -1,9 +1,9 @@ +from typing import Union, Tuple, List, Callable, Optional import math import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.init import xavier_normal_, kaiming_normal_, orthogonal_ -from typing import Union, Tuple, List, Callable from ding.compatibility import torch_ge_131 from .normalization import build_normalization @@ -214,7 +214,8 @@ def fc_block( activation: nn.Module = None, norm_type: str = None, use_dropout: bool = False, - dropout_probability: float = 0.5 + dropout_probability: float = 0.5, + init_gain: Optional[float] = None, ) -> nn.Sequential: r""" Overview: @@ -228,6 +229,7 @@ def fc_block( - norm_type (:obj:`str`): type of the normalization - use_dropout (:obj:`bool`) : whether to use dropout in the fully-connected block - dropout_probability (:obj:`float`) : probability of an element to be zeroed in the dropout. Default: 0.5 + - init_gain (:obj:`float`): FC initialization gain argument, if specified, use xavier with init_gain. 
Returns: - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block @@ -237,7 +239,10 @@ def fc_block( """ block = [] block.append(nn.Linear(in_channels, out_channels)) - if norm_type is not None: + if init_gain is not None: + torch.nn.init.xavier_uniform_(block[-1].weight, init_gain) + torch.nn.init.constant_(block[-1].bias, 0.0) + if norm_type is not None and norm_type != 'none': block.append(build_normalization(norm_type, dim=1)(out_channels)) if activation is not None: block.append(activation) @@ -634,6 +639,41 @@ def noise_block( return sequential_pack(block) +class AttentionPool(nn.Module): + + def __init__(self, key_dim, head_num, output_dim, max_num=None): + super(AttentionPool, self).__init__() + self.queries = torch.nn.Parameter(torch.zeros(1, 1, head_num, key_dim)) + torch.nn.init.xavier_uniform_(self.queries) + self.head_num = head_num + self.add_num = False + if max_num is not None: + self.add_num = True + self.num_ebed = torch.nn.Embedding(num_embeddings=max_num, embedding_dim=output_dim) + self.embed_fc = fc_block(key_dim * self.head_num, output_dim) + + def forward(self, x, num=None, mask=None): + assert len(x.shape) == 3 # batch size, tokens, channels + x_with_head = x.unsqueeze(dim=2) # add head dim + score = x_with_head * self.queries + score = score.sum(dim=3) # b, t, h + if mask is not None: + assert len(mask.shape) == 3 and mask.shape[-1] == 1 + mask = mask.repeat(1, 1, self.head_num) + score.masked_fill_(~mask.bool(), value=-1e9) + score = F.softmax(score, dim=1) + x = x.unsqueeze(dim=3).repeat(1, 1, 1, self.head_num) # b, t, c, h + score = score.unsqueeze(dim=2) # b, t, 1, h + x = x * score + x = x.sum(dim=1) # b, c, h + x = x.view(x.shape[0], -1) # b, c * h + x = self.embed_fc(x) # b, c + if self.add_num: + x = x + F.relu(self.num_ebed(num.long())) + x = F.relu(x) + return x + + class NaiveFlatten(nn.Module): def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: diff --git a/ding/torch_utils/network/res_block.py b/ding/torch_utils/network/res_block.py index af89bddd7e..4fe29e2908 100644 --- a/ding/torch_utils/network/res_block.py +++ b/ding/torch_utils/network/res_block.py @@ -1,7 +1,7 @@ import torch.nn as nn import torch -from .nn_module import conv2d_block, fc_block +from .nn_module import conv2d_block, fc_block, build_normalization class ResBlock(nn.Module): @@ -78,19 +78,28 @@ class ResFCBlock(nn.Module): forward ''' - def __init__(self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN'): + def __init__( + self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN', final_norm: bool = False + ): r""" Overview: Init the Residual Block Arguments: - in_channels (:obj:`int`): Number of channels in the input tensor - activation (:obj:`nn.Module`): the optional activation function - - norm_type (:obj:`str`): type of the normalization, defalut set to 'BN' + - norm_type (:obj:`str`): type of the normalization, default set to 'BN' + - final_norm (:obj:`bool`): Whether to add norm in final residual output. 
""" super(ResFCBlock, self).__init__() self.act = activation - self.fc1 = fc_block(in_channels, in_channels, activation=self.act, norm_type=norm_type) - self.fc2 = fc_block(in_channels, in_channels, activation=None, norm_type=norm_type) + self.final_norm = final_norm + if final_norm: + self.fc1 = fc_block(in_channels, in_channels, activation=self.act, norm_type=None) + self.fc2 = fc_block(in_channels, in_channels, activation=None, norm_type=None) + self.norm = build_normalization(norm_type)(in_channels) + else: + self.fc1 = fc_block(in_channels, in_channels, activation=self.act, norm_type=norm_type) + self.fc2 = fc_block(in_channels, in_channels, activation=None, norm_type=norm_type) def forward(self, x: torch.Tensor) -> torch.Tensor: r""" @@ -99,10 +108,43 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Arguments: - x (:obj:`torch.Tensor`): the input tensor Returns: - - x(:obj:`torch.Tensor`): the resblock output tensor + - x (:obj:`torch.Tensor`): the resblock output tensor """ + if self.final_norm: + residual = x + x = self.fc1(x) + x = self.fc2(x) + x = self.norm(x + residual) + return x + else: + residual = x + x = self.fc1(x) + x = self.fc2(x) + x = self.act(x + residual) + return x + + +class GatedConvResBlock(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=nn.ReLU(), norm_type='BN'): + super(GatedConvResBlock, self).__init__() + assert (stride == 1), stride + assert (in_channels == out_channels), '{}/{}'.format(in_channels, out_channels) + self.act = activation + self.conv1 = conv2d_block(in_channels, out_channels, 3, 1, 1, activation=self.act, norm_type=norm_type) + self.conv2 = conv2d_block(out_channels, out_channels, 3, 1, 1, activation=None, norm_type=norm_type) + self.gate = nn.Sequential( + conv2d_block(out_channels, out_channels, 1, 1, 0, activation=self.act, norm_type=None), + conv2d_block(out_channels, out_channels, 1, 1, 0, activation=self.act, norm_type=None), + conv2d_block(out_channels, out_channels, 1, 1, 0, activation=self.act, norm_type=None), + conv2d_block(out_channels, out_channels, 1, 1, 0, activation=None, norm_type=None) + ) + self.update_sp = nn.Parameter(torch.full((1, ), fill_value=0.1)) + + def forward(self, x, noise_map): residual = x - x = self.fc1(x) - x = self.fc2(x) + x = self.conv1(x) + x = self.conv2(x) + x = torch.tanh(x * torch.sigmoid(self.gate(noise_map))) * self.update_sp x = self.act(x + residual) return x diff --git a/ding/torch_utils/network/rnn.py b/ding/torch_utils/network/rnn.py index 0f87ffbcb4..bbe3bb627c 100644 --- a/ding/torch_utils/network/rnn.py +++ b/ding/torch_utils/network/rnn.py @@ -21,23 +21,21 @@ def is_sequence(data): def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.BoolTensor: - r""" + """ Overview: - create a mask for a batch sequences with different lengths + Create a mask for a batch sequences with different lengths. Arguments: - - lengths (:obj:`torch.Tensor`): lengths in each different sequences, shape could be (n, 1) or (n) - - max_len (:obj:`int`): the padding size, if max_len is None, the padding size is the \ - max length of sequences + - lengths (:obj:`torch.Tensor`): Lengths in each different sequences, shape could be (n, 1) or (n). + - max_len (:obj:`int`): The padding size, if max_len is None, the padding size is the \ + max length of sequences. Returns: - - masks (:obj:`torch.BoolTensor`): mask has the same device as lengths + - masks (:obj:`torch.BoolTensor`): Mask has the same device as lengths. 
""" if len(lengths.shape) == 1: lengths = lengths.unsqueeze(dim=1) bz = lengths.numel() if max_len is None: max_len = lengths.max() - else: - max_len = min(max_len, lengths.max()) return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device) diff --git a/ding/torch_utils/network/scatter_connection.py b/ding/torch_utils/network/scatter_connection.py index e0385fa240..58c40204aa 100644 --- a/ding/torch_utils/network/scatter_connection.py +++ b/ding/torch_utils/network/scatter_connection.py @@ -77,11 +77,13 @@ def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torc device = x.device B, M, N = x.shape H, W = spatial_size - index = location.view(-1, 2) + index = location.view(-1, 2).long() bias = torch.arange(B).mul_(H * W).unsqueeze(1).repeat(1, M).view(-1).to(device) + index = index[:, 0] * W + index[:, 1] index += bias index = index.repeat(N, 1) + x = x.view(-1, N).permute(1, 0) output = torch.zeros(N, B * H * W, device=device) if self.scatter_type == 'cover': @@ -91,3 +93,29 @@ def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torc output = output.reshape(N, B, H, W) output = output.permute(1, 0, 2, 3).contiguous() return output + + +def scatter_connection_v2(shape, project_embeddings, entity_location, scatter_dim, scatter_type='add'): + B, H, W = shape + device = entity_location.device + entity_num = entity_location.shape[1] + index = entity_location.view(-1, 2).long() + bias = torch.arange(B).unsqueeze(1).repeat(1, entity_num).view(-1).to(device) + bias *= H * W + index[:, 0].clamp_(0, W - 1) + index[:, 1].clamp_(0, H - 1) + index = index[:, 1] * W + index[:, 0] # entity_location: (x, y), spatial_info: (y, x) + index += bias + index = index.repeat(scatter_dim, 1) + # flat scatter map and project embeddings + scatter_map = torch.zeros(scatter_dim, B * H * W, device=device) + project_embeddings = project_embeddings.view(-1, scatter_dim).permute(1, 0) + if scatter_type == 'cover': + scatter_map.scatter_(dim=1, index=index, src=project_embeddings) + elif scatter_type == 'add': + scatter_map.scatter_add_(dim=1, index=index, src=project_embeddings) + else: + raise NotImplementedError + scatter_map = scatter_map.reshape(scatter_dim, B, H, W) + scatter_map = scatter_map.permute(1, 0, 2, 3) + return scatter_map diff --git a/ding/torch_utils/network/script_lstm.py b/ding/torch_utils/network/script_lstm.py new file mode 100644 index 0000000000..dc695e0961 --- /dev/null +++ b/ding/torch_utils/network/script_lstm.py @@ -0,0 +1,244 @@ +from typing import List, Tuple +import numbers +import torch +import torch.nn as nn +import torch.jit as jit +from torch import Tensor +from ditk import logging + + +class LSTMCell(nn.Module): + + def __init__(self, input_size, hidden_size): + super(LSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size)) + self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size)) + self.bias_ih = nn.Parameter(torch.randn(4 * hidden_size)) + self.bias_hh = nn.Parameter(torch.randn(4 * hidden_size)) + + def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + hx, cx = state + gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih + torch.mm(hx, self.weight_hh.t()) + self.bias_hh) + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = 
torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * torch.tanh(cy) + + return hy, (hy, cy) + + +class LayerNormLSTMCell(nn.Module): + + def __init__(self, input_size, hidden_size): + super(LayerNormLSTMCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size)) + self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size)) + + ln = nn.LayerNorm + + self.layernorm_i = ln(4 * hidden_size) + self.layernorm_h = ln(4 * hidden_size) + self.layernorm_c = ln(hidden_size) + + def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + hx, cx = state + igates = self.layernorm_i(torch.mm(input, self.weight_ih.t())) + hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t())) + gates = igates + hgates + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate)) + hy = outgate * torch.tanh(cy) + + return hy, (hy, cy) + + +class LSTMLayer(nn.Module): + + def __init__(self, cell, *cell_args): + super(LSTMLayer, self).__init__() + self.cell = cell(*cell_args) + + def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + inputs = input.unbind(0) + outputs = torch.jit.annotate(List[Tensor], []) + for i in range(len(inputs)): + out, state = self.cell(inputs[i], state) + outputs += [out] + return torch.stack(outputs), state + + +def reverse(lst: List[Tensor]) -> List[Tensor]: + return lst[::-1] + + +class ReverseLSTMLayer(nn.Module): + + def __init__(self, cell, *cell_args): + super(ReverseLSTMLayer, self).__init__() + self.cell = cell(*cell_args) + + def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + inputs = reverse(input.unbind(0)) + outputs = jit.annotate(List[Tensor], []) + for i in range(len(inputs)): + out, state = self.cell(inputs[i], state) + outputs += [out] + return torch.stack(reverse(outputs)), state + + +class BidirLSTMLayer(nn.Module): + __constants__ = ['directions'] + + def __init__(self, cell, *cell_args): + super(BidirLSTMLayer, self).__init__() + self.directions = nn.ModuleList([ + LSTMLayer(cell, *cell_args), + ReverseLSTMLayer(cell, *cell_args), + ]) + + def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + # List[LSTMState]: [forward LSTMState, backward LSTMState] + outputs = jit.annotate(List[Tensor], []) + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 + i = 0 + for direction in self.directions: + state = states[i] + out, out_state = direction(input, state) + outputs += [out] + output_states += [out_state] + i += 1 + return torch.cat(outputs, -1), output_states + + +def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args): + layers = [layer(*first_layer_args)] + [layer(*other_layer_args) for _ in range(num_layers - 1)] + return nn.ModuleList(layers) + + +class StackedLSTM(nn.Module): + __constants__ = ['layers'] # Necessary for iterating through self.layers + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super(StackedLSTM, self).__init__() + self.layers = 
init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args) + + def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + # List[LSTMState]: One state per layer + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + output = input + # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 + i = 0 + for rnn_layer in self.layers: + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + i += 1 + return output, output_states + + +# Differs from StackedLSTM in that its forward method takes +# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM +# except we don't support overriding script methods. +# https://github.com/pytorch/pytorch/issues/10733 +class StackedLSTM2(nn.Module): + __constants__ = ['layers'] # Necessary for iterating through self.layers + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super(StackedLSTM2, self).__init__() + self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args) + + def forward(self, input: Tensor, + states: List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]: + # List[List[LSTMState]]: The outer list is for layers, + # inner list is for directions. + output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], []) + output = input + # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 + i = 0 + for rnn_layer in self.layers: + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + i += 1 + return output, output_states + + +class StackedLSTMWithDropout(nn.Module): + # Necessary for iterating through self.layers and dropout support + __constants__ = ['layers', 'num_layers'] + + def __init__(self, num_layers, layer, first_layer_args, other_layer_args): + super(StackedLSTMWithDropout, self).__init__() + self.layers = init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args) + # Introduces a Dropout layer on the outputs of each LSTM layer except + # the last layer, with dropout probability = 0.4. 
+ self.num_layers = num_layers + + if (num_layers == 1): + logging.warning( + "dropout lstm adds dropout layers after all but last " + "recurrent layer, it expects num_layers greater than " + "1, but got num_layers = 1" + ) + + self.dropout_layer = nn.Dropout(0.4) + + def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: + # List[LSTMState]: One state per layer + output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) + output = input + # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 + i = 0 + for rnn_layer in self.layers: + state = states[i] + output, out_state = rnn_layer(output, state) + # Apply the dropout layer except the last layer + if i < self.num_layers - 1: + output = self.dropout_layer(output) + output_states += [out_state] + i += 1 + return output, output_states + + +def script_lstm(input_size, hidden_size, num_layers, dropout=False, bidirectional=False, LN=False): + '''Returns a ScriptModule that mimics a PyTorch native LSTM.''' + + if bidirectional: + stack_type = StackedLSTM2 + layer_type = BidirLSTMLayer + dirs = 2 + elif dropout: + stack_type = StackedLSTMWithDropout + layer_type = LSTMLayer + dirs = 1 + else: + stack_type = StackedLSTM + layer_type = LSTMLayer + dirs = 1 + if LN: + cell = LayerNormLSTMCell + else: + cell = LSTMCell + + return stack_type( + num_layers, + layer_type, + first_layer_args=[cell, input_size, hidden_size], + other_layer_args=[cell, hidden_size * dirs, hidden_size] + ) diff --git a/ding/torch_utils/network/transformer.py b/ding/torch_utils/network/transformer.py index e707134a3f..f7f31653c9 100644 --- a/ding/torch_utils/network/transformer.py +++ b/ding/torch_utils/network/transformer.py @@ -89,7 +89,7 @@ class TransformerLayer(nn.Module): def __init__( self, input_dim: int, head_dim: int, hidden_dim: int, output_dim: int, head_num: int, mlp_num: int, - dropout: nn.Module, activation: nn.Module + dropout: nn.Module, activation: nn.Module, ln_type ) -> None: r""" Overview: @@ -117,6 +117,7 @@ def __init__( layers.append(self.dropout) self.mlp = nn.Sequential(*layers) self.layernorm2 = build_normalization('LN')(output_dim) + self.ln_type = ln_type def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -128,10 +129,22 @@ def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tens - output (:obj:`Tuple[torch.Tensor, torch.Tensor]`): predict value and mask """ x, mask = inputs - a = self.dropout(self.attention(x, mask)) - x = self.layernorm1(x + a) - m = self.dropout(self.mlp(x)) - x = self.layernorm2(x + m) + if self.ln_type == "pre": + a = self.attention(self.layernorm1(x), mask) + if self.dropout: + a = self.dropout(a) + x = x + a + m = self.mlp(self.layernorm2(x)) + if self.dropout: + m = self.dropout(m) + x = x + m + elif self.ln_type == "post": + a = self.dropout(self.attention(x, mask)) + x = self.layernorm1(x + a) + m = self.dropout(self.mlp(x)) + x = self.layernorm2(x + m) + else: + raise NotImplementedError(self.ln_type) return (x, mask) @@ -156,6 +169,7 @@ def __init__( layer_num: int = 3, dropout_ratio: float = 0., activation: nn.Module = nn.ReLU(), + ln_type='pre' ): r""" Overview: @@ -179,7 +193,9 @@ def __init__( self.dropout = nn.Dropout(dropout_ratio) for i in range(layer_num): layers.append( - TransformerLayer(dims[i], head_dim, hidden_dim, dims[i + 1], head_num, mlp_num, self.dropout, self.act) + TransformerLayer( + dims[i], head_dim, hidden_dim, dims[i + 1], head_num, mlp_num, 
self.dropout, self.act, ln_type + ) ) self.main = nn.Sequential(*layers) diff --git a/ding/torch_utils/tests/test_data_helper.py b/ding/torch_utils/tests/test_data_helper.py index 218ce59ba7..11ba45184c 100644 --- a/ding/torch_utils/tests/test_data_helper.py +++ b/ding/torch_utils/tests/test_data_helper.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader from ding.torch_utils import CudaFetcher, to_device, to_dtype, to_tensor, to_ndarray, to_list, \ - tensor_to_list, same_shape, build_log_buffer, get_tensor_data + tensor_to_list, same_shape, build_log_buffer, get_tensor_data, detach_grad, flatten from ding.utils import EasyTimer @@ -217,3 +217,37 @@ def test_to_device_cpu(setup_data_dict): other = EasyTimer() with pytest.raises(TypeError): to_device(other) + + +@pytest.mark.unittest +def test_detach_grad(): + tensor_list = [torch.tensor(1., requires_grad=True) for _ in range(4)] + tensor_dict = {'a': torch.tensor(1., requires_grad=True), 'b': torch.tensor(2., requires_grad=True)} + assert all(t.requires_grad is True for t in tensor_list) + assert all(t.requires_grad is True for _, t in tensor_dict.items()) + tensor_list = detach_grad(tensor_list) + tensor_dict = detach_grad(tensor_dict) + assert all(t.requires_grad is False for t in tensor_list) + assert all(t.requires_grad is False for _, t in tensor_dict.items()) + + with pytest.raises(TypeError): + detach_grad(1) + + +@pytest.mark.unittest +def test_flatten(): + + def test_tensor(): + return torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + tensor_list = [test_tensor() for _ in range(4)] + tensor_dict = {'a': test_tensor(), 'b': test_tensor()} + assert all(t.shape == torch.Size([2, 2, 2]) for t in tensor_list) + assert all(t.shape == torch.Size([2, 2, 2]) for _, t in tensor_dict.items()) + tensor_list = flatten(tensor_list) + tensor_dict = flatten(tensor_dict) + assert all(t.shape == torch.Size([4, 2]) for t in tensor_list) + assert all(t.shape == torch.Size([4, 2]) for t in tensor_list) + + with pytest.raises(TypeError): + flatten(1) diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py index 2dd7e7e2e6..72939b9b04 100644 --- a/ding/utils/__init__.py +++ b/ding/utils/__init__.py @@ -3,7 +3,7 @@ from .compression_helper import get_data_compressor, get_data_decompressor from .default_helper import override, dicts_to_lists, lists_to_dicts, squeeze, default_get, error_wrapper, list_split, \ LimitedSpaceContainer, deep_merge_dicts, set_pkg_seed, flatten_dict, one_time_warning, split_data_generator, \ - RunningMeanStd, make_key_as_identifier, remove_illegal_item + RunningMeanStd, make_key_as_identifier, remove_illegal_item, read_yaml_config from .design_helper import SingletonMetaclass from .file_helper import read_file, save_file, remove_file from .import_helper import try_import_ceph, try_import_mc, try_import_link, import_module, try_import_redis, \ @@ -29,6 +29,7 @@ from .type_helper import SequenceType from .render_helper import render, fps from .fast_copy import fastcopy +from .sparse_logging import log_every_n, log_every_sec if ding.enable_linklink: from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \ diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py index 47eda3d31c..7f27b38491 100644 --- a/ding/utils/data/collate_fn.py +++ b/ding/utils/data/collate_fn.py @@ -34,9 +34,13 @@ def inplace_fn(t): return x -def default_collate(batch: Sequence, - cat_1dim: bool = True, - ignore_prefix: list = ['collate_ignore']) -> Union[torch.Tensor, Mapping, 
Sequence]: +def default_collate( + batch: Sequence, + dim: int = 0, + cat_1dim: bool = True, + allow_key_mismatch: bool = False, + ignore_prefix: list = ['collate_ignore'] +) -> Union[torch.Tensor, Mapping, Sequence]: """ Overview: Put each data field into a tensor with outer dimension batch size. @@ -81,9 +85,8 @@ def default_collate(batch: Sequence, if elem.shape == (1, ) and cat_1dim: # reshape (B, 1) -> (B) return torch.cat(batch, 0, out=out) - # return torch.stack(batch, 0, out=out) else: - return torch.stack(batch, 0, out=out) + return torch.stack(batch, dim, out=out) elif isinstance(elem, ttorch.Tensor): return ttorch_collate(batch, json=True) elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ @@ -92,7 +95,7 @@ def default_collate(batch: Sequence, # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype.str) is not None: raise TypeError(default_collate_err_msg_format.format(elem.dtype)) - return default_collate([torch.as_tensor(b) for b in batch], cat_1dim=cat_1dim) + return default_collate([torch.as_tensor(b) for b in batch], dim=dim, cat_1dim=cat_1dim) elif elem.shape == (): # scalars return torch.as_tensor(batch) elif isinstance(elem, float): @@ -105,16 +108,22 @@ def default_collate(batch: Sequence, elif isinstance(elem, container_abcs.Mapping): ret = {} for key in elem: - if any([key.startswith(t) for t in ignore_prefix]): - ret[key] = [d[key] for d in batch] + if allow_key_mismatch: + if any([key.startswith(t) for t in ignore_prefix]): + ret[key] = [d[key] for d in batch if key in d.keys()] + else: + ret[key] = default_collate([d[key] for d in batch if key in d.keys()], dim=dim, cat_1dim=cat_1dim) else: - ret[key] = default_collate([d[key] for d in batch], cat_1dim=cat_1dim) + if any([key.startswith(t) for t in ignore_prefix]): + ret[key] = [d[key] for d in batch] + else: + ret[key] = default_collate([d[key] for d in batch], dim=dim, cat_1dim=cat_1dim) return ret elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple - return elem_type(*(default_collate(samples, cat_1dim=cat_1dim) for samples in zip(*batch))) + return elem_type(*(default_collate(samples, dim=dim, cat_1dim=cat_1dim) for samples in zip(*batch))) elif isinstance(elem, container_abcs.Sequence): transposed = zip(*batch) - return [default_collate(samples, cat_1dim=cat_1dim) for samples in transposed] + return [default_collate(samples, dim=dim, cat_1dim=cat_1dim) for samples in transposed] raise TypeError(default_collate_err_msg_format.format(elem_type)) diff --git a/ding/utils/data/tests/test_collate_fn.py b/ding/utils/data/tests/test_collate_fn.py index 83611377c1..97b4b96bd6 100644 --- a/ding/utils/data/tests/test_collate_fn.py +++ b/ding/utils/data/tests/test_collate_fn.py @@ -105,6 +105,18 @@ def test_basic(self): assert isinstance(data, dict) assert len(data['collate_ignore_data']) == 4 + def test_dim_attirbute(self): + + def test_tensor(): + return torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + test_batch_tensor = [test_tensor() for _ in range(4)] + assert test_batch_tensor[0].shape == torch.Size([2, 2, 2]) + assert default_collate(test_batch_tensor, dim=0).shape == torch.Size([4, 2, 2, 2]) + assert default_collate(test_batch_tensor, dim=1).shape == torch.Size([2, 4, 2, 2]) + assert default_collate(test_batch_tensor, dim=2).shape == torch.Size([2, 2, 4, 2]) + assert default_collate(test_batch_tensor, dim=3).shape == torch.Size([2, 2, 2, 4]) + @pytest.mark.unittest class TestDefaultDecollate: diff --git a/ding/utils/default_helper.py 
diff --git a/ding/utils/default_helper.py b/ding/utils/default_helper.py
index 108c08d58c..2b80d46d53 100644
--- a/ding/utils/default_helper.py
+++ b/ding/utils/default_helper.py
@@ -1,10 +1,13 @@
 from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any, Dict
+from functools import lru_cache  # in python3.9, we can change to cache
+from easydict import EasyDict
+import os
 import copy
-from ditk import logging
 import random
-from functools import lru_cache  # in python3.9, we can change to cache
 import numpy as np
 import torch
+import yaml
+from ditk import logging
 
 
 def lists_to_dicts(
@@ -571,6 +574,21 @@ def legalization(s: str) -> str:
     return new_data
+
+
+def read_yaml_config(path: str) -> EasyDict:
+    """
+    Overview:
+        Read yaml configuration from given path.
+    Arguments:
+        - path (:obj:`str`): Path of source yaml.
+    Returns:
+        - cfg (:obj:`EasyDict`): Config data from this file with dict type.
+    """
+    assert os.path.exists(path), path
+    with open(path, "r") as f:
+        config = yaml.safe_load(f)
+    return EasyDict(config)
+
+
 def remove_illegal_item(data: Dict[str, Any]) -> Dict[str, Any]:
     """
     Overview:
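Note: read_yaml_config is also re-exported from ding.utils (see the __init__.py hunk earlier in this diff). A minimal usage sketch; the file name and keys are illustrative:

import yaml
from ding.utils import read_yaml_config

# Write an illustrative config file, then read it back as an EasyDict,
# whose keys are also accessible as attributes.
with open("demo_config.yaml", "w") as f:
    yaml.safe_dump({"exp_name": "demo", "env": {"n_evaluator_episode": 5}}, f)

cfg = read_yaml_config("demo_config.yaml")
assert cfg.exp_name == "demo"
assert cfg.env.n_evaluator_episode == 5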
diff --git a/ding/utils/sparse_logging.py b/ding/utils/sparse_logging.py
new file mode 100644
index 0000000000..d54f77d7c5
--- /dev/null
+++ b/ding/utils/sparse_logging.py
@@ -0,0 +1,78 @@
+from ditk import logging
+import itertools
+import timeit
+
+_log_counter_per_token = {}
+_log_timer_per_token = {}
+
+
+def _get_next_log_count_per_token(token):
+    """
+    Overview:
+        Wrapper for _log_counter_per_token. Thread-safe.
+    Arguments:
+        - token: The token for which to look up the count.
+    Returns:
+        - ret: The number of times this function has been called
+          with *token* as an argument (starting at 0).
+    """
+    # Can't use a defaultdict because defaultdict isn't atomic, whereas
+    # setdefault is.
+    return next(_log_counter_per_token.setdefault(token, itertools.count()))
+
+
+def _seconds_have_elapsed(token, num_seconds):
+    """
+    Overview:
+        Tests if 'num_seconds' have passed since 'token' was requested.
+        Not strictly thread-safe - may log with the wrong frequency if called
+        concurrently from multiple threads. Accuracy depends on resolution of
+        'timeit.default_timer()'.
+        Always returns True on the first call for a given 'token'.
+    Arguments:
+        - token: The token for which to look up the count.
+        - num_seconds: The number of seconds to test for.
+    Returns:
+        - ret: Whether it has been >= 'num_seconds' since 'token' was last requested.
+    """
+    now = timeit.default_timer()
+    then = _log_timer_per_token.get(token, None)
+    if then is None or (now - then) >= num_seconds:
+        _log_timer_per_token[token] = now
+        return True
+    else:
+        return False
+
+
+def log_every_n(level, n, msg, *args):
+    """
+    Overview:
+        Logs 'msg % args' at level 'level' once per 'n' times.
+        Logs the 1st call, (N+1)st call, (2N+1)st call, etc.
+        Not thread-safe.
+    Arguments:
+        - level (:obj:`int`): the logging level at which to log.
+        - n (:obj:`int`): the number of times this should be called before it is logged.
+        - msg (:obj:`str`): the message to be logged.
+        - *args: The args to be substituted into the msg.
+    """
+    count = _get_next_log_count_per_token(logging.getLogger().findCaller())
+    if count % n == 0:
+        logging.log(level, msg, *args)
+
+
+def log_every_sec(level, n_seconds, msg, *args):
+    """
+    Overview:
+        Logs 'msg % args' at level 'level' only if at least 'n_seconds' have elapsed since the last call.
+        Logs the first call, then logs subsequent calls only if 'n_seconds' have elapsed since
+        the last logging call from the same call site (file + line). Not thread-safe.
+    Arguments:
+        - level (:obj:`int`): the logging level at which to log.
+        - n_seconds (:obj:`Union[int, float]`): seconds which should elapse before logging again.
+        - msg (:obj:`str`): the message to be logged.
+        - *args: The args to be substituted into the msg.
+    """
+    should_log = _seconds_have_elapsed(logging.getLogger().findCaller(), n_seconds)
+    if should_log:
+        logging.log(level, msg, *args)
diff --git a/ding/utils/tests/test_sparse_logging.py b/ding/utils/tests/test_sparse_logging.py
new file mode 100644
index 0000000000..47b7085157
--- /dev/null
+++ b/ding/utils/tests/test_sparse_logging.py
@@ -0,0 +1,15 @@
+import pytest
+import logging
+import time
+from ding.utils import log_every_n, log_every_sec
+
+
+@pytest.mark.unittest
+def test_sparse_logging():
+    logging.getLogger().setLevel(logging.INFO)
+    for i in range(30):
+        log_every_n(logging.INFO, 5, "abc_{}".format(i))
+
+    for i in range(30):
+        time.sleep(0.1)
+        log_every_sec(logging.INFO, 1, "abc_{}".format(i))
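Note: these helpers are meant to throttle repetitive logs in long-running loops, as the test above suggests. A minimal sketch of typical usage; the loop, message and thresholds are illustrative:

import logging
from ding.utils import log_every_n, log_every_sec

logging.getLogger().setLevel(logging.INFO)

for train_iter in range(10000):
    loss = 1.0 / (train_iter + 1)  # placeholder value, real loss computation omitted
    # logged on the 1st, 101st, 201st, ... call instead of every iteration
    log_every_n(logging.INFO, 100, "train_iter %d, loss %.4f", train_iter, loss)
    # logged at most once every 30 seconds, regardless of iteration speed
    log_every_sec(logging.INFO, 30, "train_iter %d, loss %.4f", train_iter, loss)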