Skip to content

Commit

Permalink
fix(nyz): fix gtrxl compatibility bug (#796)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed May 28, 2024
1 parent b2aab8d commit 13a6d45
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
3 changes: 2 additions & 1 deletion — ding/model/template/q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,8 +1118,9 @@ def __init__(
gru_bias=gru_bias,
)

# for vector obs, use Identity Encoder, i.e. pass
if isinstance(obs_shape, int) or len(obs_shape) == 1:
raise NotImplementedError("not support obs_shape for pre-defined encoder: {}".format(obs_shape))
pass
# replace the embedding layer of Transformer with Conv Encoder
elif len(obs_shape) == 3:
assert encoder_hidden_size_list[-1] == hidden_size
Expand Down
11 changes: 4 additions & 7 deletions — ding/policy/r2d2_gtrxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,9 @@ class R2D2GTrXLPolicy(Policy):
| ``done`` | calculation. | fake termination env
15 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
| call of collector. | different envs
16 | ``collect.unroll`` int 25 | unroll length of an iteration | unroll_len>1
16 | ``collect.seq`` int 20 | Training sequence length | unroll_len>=seq_len>1
| ``_len``
17 | ``collect.seq`` int 20 | Training sequence length | unroll_len>=seq_len>1
| ``_len``
18 | ``learn.init_`` str zero | 'zero' or 'old', how to initialize the |
17 | ``learn.init_`` str zero | 'zero' or 'old', how to initialize the |
| ``memory`` | memory before each training iteration. |
== ==================== ======== ============== ======================================== =======================
"""
Expand All @@ -81,7 +79,7 @@ class R2D2GTrXLPolicy(Policy):
discount_factor=0.99,
# (int) N-step reward for target q_value estimation
nstep=5,
# how many steps to use as burnin
# (int) How many steps to use in burnin phase
burnin_step=1,
# (int) trajectory length
unroll_len=25,
Expand Down Expand Up @@ -158,7 +156,7 @@ def _init_learn(self) -> None:
self._seq_len = self._cfg.seq_len
self._value_rescale = self._cfg.learn.value_rescale
self._init_memory = self._cfg.learn.init_memory
assert self._init_memory in ['zero', 'old']
assert self._init_memory in ['zero', 'old'], self._init_memory

self._target_model = copy.deepcopy(self._model)

Expand Down Expand Up @@ -352,7 +350,6 @@ def _init_collect(self) -> None:
Collect mode init method. Called by ``self.__init__``.
Init unroll length and sequence len, collect model.
"""
assert 'unroll_len' not in self._cfg.collect, "Use default unroll_len"
self._nstep = self._cfg.nstep
self._gamma = self._cfg.discount_factor
self._unroll_len = self._cfg.unroll_len
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
obs_shape=4,
action_shape=2,
memory_len=5, # length of transformer memory (can be 0)
hidden_size=256,
hidden_size=64,
gru_bias=2.,
att_layer_num=3,
dropout=0.,
att_head_num=8,
att_head_num=4,
),
discount_factor=0.99,
nstep=3,
Expand All @@ -31,7 +31,7 @@
seq_len=8, # transformer input segment
# training sequence: unroll_len - burnin_step - nstep
learn=dict(
update_per_collect=8,
update_per_collect=16,
batch_size=64,
learning_rate=0.0005,
target_update_freq=500,
Expand Down

0 comments on commit 13a6d45

Please sign in to comment.