diff --git a/ding/worker/collector/base_serial_collector.py b/ding/worker/collector/base_serial_collector.py
index 75f4bbe8be..caa88410af 100644
--- a/ding/worker/collector/base_serial_collector.py
+++ b/ding/worker/collector/base_serial_collector.py
@@ -197,28 +197,33 @@ def append(self, data: Any) -> None:
         super().append(data)
 
 
-def to_tensor_transitions(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+def to_tensor_transitions(data: List[Dict[str, Any]], shallow_copy_next_obs: bool = True) -> List[Dict[str, Any]]:
     """
     Overview:
-        transitions collected data to tensor.
+        Transform the original transitions returned by the env into tensor format.
     Argument:
-        - data (:obj:`List[Dict[str, Any]]`): the data that will be transited to tensor.
+        - data (:obj:`List[Dict[str, Any]]`): The data that will be transformed to tensor.
+        - shallow_copy_next_obs (:obj:`bool`): Whether to shallow copy next_obs. Default: True.
     Return:
-        - data (:obj:`List[Dict[str, Any]]`): the data that can be transited to tensor.
+        - data (:obj:`List[Dict[str, Any]]`): The transformed tensor-like data.
 
     .. tip::
         In order to save memory, If there are next_obs in the passed data, we do special \
-            treatment on next_obs so that the next_obs of each state in the data fragment is \
-            the next state's obs and the next_obs of the last state is its own next_obs, \
-            and we make transform_scalar is False.
+        treatment on next_obs so that the next_obs of each state in the data fragment is \
+        the next state's obs and the next_obs of the last state is its own next_obs. \
+        Besides, we set transform_scalar to False to avoid the extra ``.item()`` operation.
     """
     if 'next_obs' not in data[0]:
         return to_tensor(data, transform_scalar=False)
     else:
-        # for save memory of next_obs
-        data = to_tensor(data, transform_scalar=False)
-        data = to_tensor(data, ignore_keys=['next_obs'], transform_scalar=False)
-        for i in range(len(data) - 1):
-            data[i]['next_obs'] = data[i + 1]['obs']
-        data[-1]['next_obs'] = to_tensor(data[-1]['next_obs'], transform_scalar=False)
+        # to_tensor assigns separate memory to next_obs; if shallow_copy_next_obs is True,
+        # we pass ignore_keys to skip this copy and save the memory occupied by next_obs.
+        if shallow_copy_next_obs:
+            data = to_tensor(data, ignore_keys=['next_obs'], transform_scalar=False)
+            for i in range(len(data) - 1):
+                data[i]['next_obs'] = data[i + 1]['obs']
+            data[-1]['next_obs'] = to_tensor(data[-1]['next_obs'], transform_scalar=False)
+            return data
+        else:
+            data = to_tensor(data, transform_scalar=False)
         return data
diff --git a/ding/worker/collector/battle_episode_serial_collector.py b/ding/worker/collector/battle_episode_serial_collector.py
index aa39de1669..6609adcaea 100644
--- a/ding/worker/collector/battle_episode_serial_collector.py
+++ b/ding/worker/collector/battle_episode_serial_collector.py
@@ -260,7 +260,9 @@ def collect(self,
                         self._traj_buffer[env_id][policy_id].append(transition)
                         # prepare data
                         if timestep.done:
-                            transitions = to_tensor_transitions(self._traj_buffer[env_id][policy_id])
+                            transitions = to_tensor_transitions(
+                                self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
+                            )
                             if self._cfg.get_train_sample:
                                 train_sample = self._policy[policy_id].get_train_sample(transitions)
                                 return_data[policy_id].extend(train_sample)
diff --git a/ding/worker/collector/battle_sample_serial_collector.py b/ding/worker/collector/battle_sample_serial_collector.py
index e047708167..dffc43f5f7 100644
--- a/ding/worker/collector/battle_sample_serial_collector.py
+++ b/ding/worker/collector/battle_sample_serial_collector.py
@@ -276,7 +276,9 @@ def collect(
                         self._traj_buffer[env_id][policy_id].append(transition)
                         # prepare data
                         if timestep.done or len(self._traj_buffer[env_id][policy_id]) == self._traj_len:
-                            transitions = to_tensor_transitions(self._traj_buffer[env_id][policy_id])
+                            transitions = to_tensor_transitions(
+                                self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
+                            )
                             train_sample = self._policy[policy_id].get_train_sample(transitions)
                             return_data[policy_id].extend(train_sample)
                             self._total_train_sample_count += len(train_sample)
diff --git a/ding/worker/collector/episode_serial_collector.py b/ding/worker/collector/episode_serial_collector.py
index ed9b1fd6f2..6fca2283f8 100644
--- a/ding/worker/collector/episode_serial_collector.py
+++ b/ding/worker/collector/episode_serial_collector.py
@@ -253,7 +253,7 @@ def collect(self,
                     self._total_envstep_count += 1
                     # prepare data
                     if timestep.done:
-                        transitions = to_tensor_transitions(self._traj_buffer[env_id])
+                        transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
                         if self._cfg.reward_shaping:
                             self._env.reward_shaping(env_id, transitions)
                         if self._cfg.get_train_sample:
diff --git a/ding/worker/collector/sample_serial_collector.py b/ding/worker/collector/sample_serial_collector.py
index a8311c2c63..fd6bddea40 100644
--- a/ding/worker/collector/sample_serial_collector.py
+++ b/ding/worker/collector/sample_serial_collector.py
@@ -47,7 +47,7 @@
         self._exp_name = exp_name
         self._instance_name = instance_name
         self._collect_print_freq = cfg.collect_print_freq
-        self._deepcopy_obs = cfg.deepcopy_obs  # avoid shallow copy, e.g., ovelap of s_t and s_t+1
+        self._deepcopy_obs = cfg.deepcopy_obs  # whether to deepcopy each data item
         self._transform_obs = cfg.transform_obs
         self._cfg = cfg
         self._timer = EasyTimer()
@@ -288,7 +288,8 @@ def collect(
                        # sequence sample of length (please refer to r2d2.py).
                        # Episode is done or traj_buffer(maxlen=traj_len) is full.
-                        transitions = to_tensor_transitions(self._traj_buffer[env_id])
+                        # indicate whether to shallow copy next_obs, i.e., next_obs of s_t is shared with obs of s_t+1
+                        transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
                         train_sample = self._policy.get_train_sample(transitions)
                         return_data.extend(train_sample)
                         self._total_train_sample_count += len(train_sample)
diff --git a/ding/worker/collector/tests/test_base_serial_collector.py b/ding/worker/collector/tests/test_base_serial_collector.py
new file mode 100644
index 0000000000..475a6a4b17
--- /dev/null
+++ b/ding/worker/collector/tests/test_base_serial_collector.py
@@ -0,0 +1,42 @@
+import pytest
+import numpy as np
+import torch
+from ding.worker.collector.base_serial_collector import to_tensor_transitions
+
+
+def get_transition():
+    return {
+        'obs': np.random.random((2, 3)),
+        'action': np.random.randint(0, 6, size=(1, )),
+        'reward': np.random.random((1, )),
+        'done': False,
+        'next_obs': np.random.random((2, 3)),
+    }
+
+
+@pytest.mark.unittest
+def test_to_tensor_transitions():
+    # test case when shallow copy is True
+    transition_list = [get_transition() for _ in range(4)]
+    tensor_list = to_tensor_transitions(transition_list, shallow_copy_next_obs=True)
+    for i in range(len(tensor_list)):
+        tensor = tensor_list[i]
+        assert isinstance(tensor['obs'], torch.Tensor)
+        assert isinstance(tensor['action'], torch.Tensor), type(tensor['action'])
+        assert isinstance(tensor['reward'], torch.Tensor)
+        assert isinstance(tensor['done'], bool)
+        assert 'next_obs' in tensor
+        if i < len(tensor_list) - 1:
+            assert id(tensor['next_obs']) == id(tensor_list[i + 1]['obs'])
+    # test case when shallow copy is False
+    transition_list = [get_transition() for _ in range(4)]
+    tensor_list = to_tensor_transitions(transition_list, shallow_copy_next_obs=False)
+    for i in range(len(tensor_list)):
+        tensor = tensor_list[i]
+        assert isinstance(tensor['obs'], torch.Tensor)
+        assert isinstance(tensor['action'], torch.Tensor)
+        assert isinstance(tensor['reward'], torch.Tensor)
+        assert isinstance(tensor['done'], bool)
+        assert 'next_obs' in tensor
+        if i < len(tensor_list) - 1:
+            assert id(tensor['next_obs']) != id(tensor_list[i + 1]['obs'])
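
Note: the following is a minimal usage sketch of the new ``shallow_copy_next_obs`` flag, not part of the patch above. It assumes ``ding`` is importable and reuses the dummy transition layout from the unit test; ``fake_transition`` is a hypothetical helper introduced only for illustration. The collectors pick between the two paths by passing ``not self._deepcopy_obs``, so ``cfg.deepcopy_obs=True`` maps to the deep-copy branch.

    import numpy as np
    from ding.worker.collector.base_serial_collector import to_tensor_transitions

    def fake_transition():
        # Dummy transition with the same fields the unit test uses.
        return {
            'obs': np.random.random((2, 3)),
            'action': np.random.randint(0, 6, size=(1, )),
            'reward': np.random.random((1, )),
            'done': False,
            'next_obs': np.random.random((2, 3)),
        }

    # Default shallow-copy path: next_obs of step i aliases obs of step i + 1,
    # so no extra tensor is allocated for the intermediate observations.
    shallow = to_tensor_transitions([fake_transition() for _ in range(4)])
    assert shallow[0]['next_obs'] is shallow[1]['obs']

    # Deep-copy path (what a collector selects when cfg.deepcopy_obs is True):
    # every next_obs keeps its own tensor, costing more memory but staying safe
    # if downstream code mutates observations in place.
    deep = to_tensor_transitions([fake_transition() for _ in range(4)], shallow_copy_next_obs=False)
    assert deep[0]['next_obs'] is not deep[1]['obs']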