fix(nyz): fix confusing shallow copy operation bug about next_obs #641

Merged: 1 commit, Apr 11, 2023
31 changes: 18 additions & 13 deletions ding/worker/collector/base_serial_collector.py
@@ -197,28 +197,33 @@ def append(self, data: Any) -> None:
super().append(data)


def to_tensor_transitions(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def to_tensor_transitions(data: List[Dict[str, Any]], shallow_copy_next_obs: bool = True) -> List[Dict[str, Any]]:
"""
Overview:
transitions collected data to tensor.
Transform the original transitions returned from env into tensor format.
Argument:
- data (:obj:`List[Dict[str, Any]]`): the data that will be transited to tensor.
- data (:obj:`List[Dict[str, Any]]`): The data that will be transformed to tensor.
- shallow_copy_next_obs (:obj:`bool`): Whether to shallow copy next_obs. Default: True.
Return:
- data (:obj:`List[Dict[str, Any]]`): the data that can be transited to tensor.
- data (:obj:`List[Dict[str, Any]]`): The transformed tensor-like data.

.. tip::
In order to save memory, if there are next_obs in the passed data, we do special \
treatment on next_obs so that the next_obs of each state in the data fragment is \
the next state's obs and the next_obs of the last state is its own next_obs, \
and we make transform_scalar is False.
treatment on next_obs so that the next_obs of each state in the data fragment is \
the next state's obs and the next_obs of the last state is its own next_obs. \
Besides, we set transform_scalar to False to avoid the extra ``.item()`` operation.
"""
if 'next_obs' not in data[0]:
return to_tensor(data, transform_scalar=False)
else:
# for save memory of next_obs
data = to_tensor(data, transform_scalar=False)
data = to_tensor(data, ignore_keys=['next_obs'], transform_scalar=False)
for i in range(len(data) - 1):
data[i]['next_obs'] = data[i + 1]['obs']
data[-1]['next_obs'] = to_tensor(data[-1]['next_obs'], transform_scalar=False)
# to_tensor would assign separate memory to next_obs; if shallow_copy_next_obs is True,
# we pass ignore_keys to skip that copy and save the memory occupied by next_obs.
if shallow_copy_next_obs:
data = to_tensor(data, ignore_keys=['next_obs'], transform_scalar=False)
for i in range(len(data) - 1):
data[i]['next_obs'] = data[i + 1]['obs']
data[-1]['next_obs'] = to_tensor(data[-1]['next_obs'], transform_scalar=False)
return data
else:
data = to_tensor(data, transform_scalar=False)
return data
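To make the shallow-copy scheme concrete, here is a minimal standalone sketch (plain PyTorch, not DI-engine code, with invented placeholder data) of the linking that `to_tensor_transitions` performs when `shallow_copy_next_obs=True`: each transition's `next_obs` simply references the following transition's `obs`, so only the last step materializes a dedicated `next_obs` tensor.

```python
import torch

# four fake transitions, each initially carrying its own next_obs tensor
transitions = [{'obs': torch.randn(2, 3), 'next_obs': torch.randn(2, 3)} for _ in range(4)]

# shallow-copy linking: next_obs of step i references obs of step i + 1 (no data copy)
for i in range(len(transitions) - 1):
    transitions[i]['next_obs'] = transitions[i + 1]['obs']

# only the last transition keeps a dedicated next_obs tensor
assert transitions[0]['next_obs'] is transitions[1]['obs']
assert transitions[-1]['next_obs'] is not transitions[-1]['obs']
```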
4 changes: 3 additions & 1 deletion ding/worker/collector/battle_episode_serial_collector.py
@@ -260,7 +260,9 @@ def collect(self,
self._traj_buffer[env_id][policy_id].append(transition)
# prepare data
if timestep.done:
transitions = to_tensor_transitions(self._traj_buffer[env_id][policy_id])
transitions = to_tensor_transitions(
self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
)
if self._cfg.get_train_sample:
train_sample = self._policy[policy_id].get_train_sample(transitions)
return_data[policy_id].extend(train_sample)
4 changes: 3 additions & 1 deletion ding/worker/collector/battle_sample_serial_collector.py
@@ -276,7 +276,9 @@ def collect(
self._traj_buffer[env_id][policy_id].append(transition)
# prepare data
if timestep.done or len(self._traj_buffer[env_id][policy_id]) == self._traj_len:
transitions = to_tensor_transitions(self._traj_buffer[env_id][policy_id])
transitions = to_tensor_transitions(
self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
)
train_sample = self._policy[policy_id].get_train_sample(transitions)
return_data[policy_id].extend(train_sample)
self._total_train_sample_count += len(train_sample)
2 changes: 1 addition & 1 deletion ding/worker/collector/episode_serial_collector.py
@@ -253,7 +253,7 @@ def collect(self,
self._total_envstep_count += 1
# prepare data
if timestep.done:
transitions = to_tensor_transitions(self._traj_buffer[env_id])
transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
if self._cfg.reward_shaping:
self._env.reward_shaping(env_id, transitions)
if self._cfg.get_train_sample:
5 changes: 3 additions & 2 deletions ding/worker/collector/sample_serial_collector.py
@@ -47,7 +47,7 @@ def __init__(
self._exp_name = exp_name
self._instance_name = instance_name
self._collect_print_freq = cfg.collect_print_freq
self._deepcopy_obs = cfg.deepcopy_obs # avoid shallow copy, e.g., ovelap of s_t and s_t+1
self._deepcopy_obs = cfg.deepcopy_obs # whether to deepcopy each data item
self._transform_obs = cfg.transform_obs
self._cfg = cfg
self._timer = EasyTimer()
@@ -288,7 +288,8 @@ def collect(
# sequence sample of length <burnin + learn_unroll_len> (please refer to r2d2.py).

# Episode is done or traj_buffer(maxlen=traj_len) is full.
transitions = to_tensor_transitions(self._traj_buffer[env_id])
# the second argument indicates whether to shallow copy next_obs, i.e., overlap of s_t's next_obs and s_t+1's obs
transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
train_sample = self._policy.get_train_sample(transitions)
return_data.extend(train_sample)
self._total_train_sample_count += len(train_sample)
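For reference, the collector-side wiring is simply `shallow_copy_next_obs = not deepcopy_obs`. A hypothetical config sketch follows; the key names appear in the diff above, while the values are only illustrative:

```python
from easydict import EasyDict

# hypothetical collector config fragment; only deepcopy_obs matters for this PR
cfg = EasyDict(dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100))

# deepcopy_obs=False lets next_obs share memory with the following step's obs,
# matching to_tensor_transitions(..., shallow_copy_next_obs=True)
shallow_copy_next_obs = not cfg.deepcopy_obs
assert shallow_copy_next_obs is True
```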
42 changes: 42 additions & 0 deletions ding/worker/collector/tests/test_base_serial_collector.py
@@ -0,0 +1,42 @@
import pytest
import numpy as np
import torch
from ding.worker.collector.base_serial_collector import to_tensor_transitions


def get_transition():
    return {
        'obs': np.random.random((2, 3)),
        'action': np.random.randint(0, 6, size=(1, )),
        'reward': np.random.random((1, )),
        'done': False,
        'next_obs': np.random.random((2, 3)),
    }


@pytest.mark.unittest
def test_to_tensor_transitions():
    # test case when shallow copy is True
    transition_list = [get_transition() for _ in range(4)]
    tensor_list = to_tensor_transitions(transition_list, shallow_copy_next_obs=True)
    for i in range(len(tensor_list)):
        tensor = tensor_list[i]
        assert isinstance(tensor['obs'], torch.Tensor)
        assert isinstance(tensor['action'], torch.Tensor), type(tensor['action'])
        assert isinstance(tensor['reward'], torch.Tensor)
        assert isinstance(tensor['done'], bool)
        assert 'next_obs' in tensor
        if i < len(tensor_list) - 1:
            assert id(tensor['next_obs']) == id(tensor_list[i + 1]['obs'])
    # test case when shallow copy is False
    transition_list = [get_transition() for _ in range(4)]
    tensor_list = to_tensor_transitions(transition_list, shallow_copy_next_obs=False)
    for i in range(len(tensor_list)):
        tensor = tensor_list[i]
        assert isinstance(tensor['obs'], torch.Tensor)
        assert isinstance(tensor['action'], torch.Tensor)
        assert isinstance(tensor['reward'], torch.Tensor)
        assert isinstance(tensor['done'], bool)
        assert 'next_obs' in tensor
        if i < len(tensor_list) - 1:
            assert id(tensor['next_obs']) != id(tensor_list[i + 1]['obs'])
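Beyond the `id()` checks in the new test, storage sharing of the linked tensors could also be confirmed with `Tensor.data_ptr()`; this extra snippet is a suggestion, not part of the PR:

```python
from ding.worker.collector.base_serial_collector import to_tensor_transitions

# reuses get_transition() from the test above; object identity of the linked
# tensors implies shared storage, which data_ptr() makes explicit
extra_list = to_tensor_transitions([get_transition() for _ in range(3)], shallow_copy_next_obs=True)
assert extra_list[0]['next_obs'].data_ptr() == extra_list[1]['obs'].data_ptr()
```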