Skip to content

Commit

Permalink
fix(nyz): fix confusing shallow copy operation about next_obs (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 authored Apr 11, 2023
1 parent 1cb1038 commit 283ef35
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 18 deletions.
31 changes: 18 additions & 13 deletions ding/worker/collector/base_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,28 +197,33 @@ def append(self, data: Any) -> None:
super().append(data)


def to_tensor_transitions(data: List[Dict[str, Any]], shallow_copy_next_obs: bool = True) -> List[Dict[str, Any]]:
    """
    Overview:
        Transform the original transitions returned from env to tensor format.
    Argument:
        - data (:obj:`List[Dict[str, Any]]`): The data that will be transformed to tensor.
        - shallow_copy_next_obs (:obj:`bool`): Whether to shallow copy next_obs. Default: True.
    Return:
        - data (:obj:`List[Dict[str, Any]]`): The transformed tensor-like data.

    .. tip::
        In order to save memory, if there is next_obs in the passed data, we do special \
        treatment on next_obs so that the next_obs of each state in the data fragment is \
        the next state's obs and the next_obs of the last state is its own next_obs. \
        Besides, we set transform_scalar to False to avoid the extra ``.item()`` operation.
    """
    if 'next_obs' not in data[0]:
        return to_tensor(data, transform_scalar=False)
    else:
        # to_tensor would assign separate memory to next_obs; when shallow_copy_next_obs
        # is True we pass ignore_keys to skip that copy and save the memory of next_obs.
        if shallow_copy_next_obs:
            data = to_tensor(data, ignore_keys=['next_obs'], transform_scalar=False)
            # Share memory: each transition's next_obs aliases the next transition's obs.
            for i in range(len(data) - 1):
                data[i]['next_obs'] = data[i + 1]['obs']
            # The last transition has no successor, so convert its own next_obs directly.
            data[-1]['next_obs'] = to_tensor(data[-1]['next_obs'], transform_scalar=False)
            return data
        else:
            # Full conversion: next_obs gets its own memory (no aliasing with obs).
            data = to_tensor(data, transform_scalar=False)
            return data
4 changes: 3 additions & 1 deletion ding/worker/collector/battle_episode_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,9 @@ def collect(self,
self._traj_buffer[env_id][policy_id].append(transition)
# prepare data
if timestep.done:
transitions = to_tensor_transitions(self._traj_buffer[env_id][policy_id])
transitions = to_tensor_transitions(
self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
)
if self._cfg.get_train_sample:
train_sample = self._policy[policy_id].get_train_sample(transitions)
return_data[policy_id].extend(train_sample)
Expand Down
4 changes: 3 additions & 1 deletion ding/worker/collector/battle_sample_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ def collect(
self._traj_buffer[env_id][policy_id].append(transition)
# prepare data
if timestep.done or len(self._traj_buffer[env_id][policy_id]) == self._traj_len:
transitions = to_tensor_transitions(self._traj_buffer[env_id][policy_id])
transitions = to_tensor_transitions(
self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
)
train_sample = self._policy[policy_id].get_train_sample(transitions)
return_data[policy_id].extend(train_sample)
self._total_train_sample_count += len(train_sample)
Expand Down
2 changes: 1 addition & 1 deletion ding/worker/collector/episode_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def collect(self,
self._total_envstep_count += 1
# prepare data
if timestep.done:
transitions = to_tensor_transitions(self._traj_buffer[env_id])
transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
if self._cfg.reward_shaping:
self._env.reward_shaping(env_id, transitions)
if self._cfg.get_train_sample:
Expand Down
5 changes: 3 additions & 2 deletions ding/worker/collector/sample_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self._exp_name = exp_name
self._instance_name = instance_name
self._collect_print_freq = cfg.collect_print_freq
self._deepcopy_obs = cfg.deepcopy_obs # avoid shallow copy, e.g., ovelap of s_t and s_t+1
self._deepcopy_obs = cfg.deepcopy_obs # whether to deepcopy each data
self._transform_obs = cfg.transform_obs
self._cfg = cfg
self._timer = EasyTimer()
Expand Down Expand Up @@ -288,7 +288,8 @@ def collect(
# sequence sample of length <burnin + learn_unroll_len> (please refer to r2d2.py).

# Episode is done or traj_buffer(maxlen=traj_len) is full.
transitions = to_tensor_transitions(self._traj_buffer[env_id])
# indicate whether to shallow copy next obs, i.e., overlap of s_t and s_t+1
transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
train_sample = self._policy.get_train_sample(transitions)
return_data.extend(train_sample)
self._total_train_sample_count += len(train_sample)
Expand Down
42 changes: 42 additions & 0 deletions ding/worker/collector/tests/test_base_serial_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import numpy as np
import torch
from ding.worker.collector.base_serial_collector import to_tensor_transitions


def get_transition():
    """Build one fake env transition (numpy fields) for testing the converter."""
    obs_shape = (2, 3)
    # NOTE: insertion order and np.random call order match the conventional
    # transition layout: obs, action, reward, done, next_obs.
    transition = {}
    transition['obs'] = np.random.random(obs_shape)
    transition['action'] = np.random.randint(0, 6, size=(1, ))
    transition['reward'] = np.random.random((1, ))
    transition['done'] = False
    transition['next_obs'] = np.random.random(obs_shape)
    return transition


@pytest.mark.unittest
def test_to_tensor_transitions():
    """Check tensor conversion and next_obs aliasing for both shallow-copy modes."""
    for shallow in (True, False):
        raw = [get_transition() for _ in range(4)]
        converted = to_tensor_transitions(raw, shallow_copy_next_obs=shallow)
        last = len(converted) - 1
        for idx, item in enumerate(converted):
            assert isinstance(item['obs'], torch.Tensor)
            assert isinstance(item['action'], torch.Tensor), type(item['action'])
            assert isinstance(item['reward'], torch.Tensor)
            assert isinstance(item['done'], bool)
            assert 'next_obs' in item
            if idx < last:
                # With shallow copy, next_obs must alias the following obs; without, it must not.
                aliased = item['next_obs'] is converted[idx + 1]['obs']
                assert aliased == shallow

0 comments on commit 283ef35

Please sign in to comment.