diff --git a/chainerrl/replay_buffer.py b/chainerrl/replay_buffer.py index 43981a367..fce0e6559 100644 --- a/chainerrl/replay_buffer.py +++ b/chainerrl/replay_buffer.py @@ -228,7 +228,9 @@ def batch_recurrent_experiences( 'next_recurrent_state': model.concatenate_recurrent_states( [ep[0]['next_recurrent_state'] for ep in experiences]), } - if all(elem[-1]['next_action'] is not None for elem in experiences): + # Batch next actions only when all the transitions have them + if all(transition['next_action'] is not None + for transition in flat_transitions): batch_exp['next_action'] = xp.asarray( [transition['next_action'] for transition in flat_transitions]) return batch_exp