Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[BUGFIX] Fix _get_rnn_cell #648

Merged
merged 2 commits into from
May 8, 2019
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/gluonnlp/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,11 @@ def _get_rnn_cell(mode, num_layers, input_size, hidden_size,
Only available when the mode=lstmpc.
"""

assert mode == 'lstmpc' and proj_size is not None, \
assert mode == 'lstmpc' or proj_size is None, \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch! proj_size is required for lstmpc. One simple modification would be removing mode == lstmpc, and add a condition if mode == 'lstmpc'.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry in the last review I didn't realize this change will break lstmpc cell. What about moving the assertion inside LSTMPCellWithClip to check proj_size, cell_clip and proj_clip ?

super(LSTMPCellWithClip, self).__init__(hidden_size,
projection_size,
i2h_weight_initializer,
h2h_weight_initializer,
h2r_weight_initializer,
i2h_bias_initializer,
h2h_bias_initializer,
input_size,
prefix=prefix,
params=params)
self._cell_clip = cell_clip
self._projection_clip = projection_clip

'proj_size takes effect only when mode is lstmpc'
assert mode == 'lstmpc' and cell_clip is not None, \
assert mode == 'lstmpc' or cell_clip is None, \
'cell_clip takes effect only when mode is lstmpc'
assert mode == 'lstmpc' and proj_clip is not None, \
assert mode == 'lstmpc' or proj_clip is None, \
'proj_clip takes effect only when mode is lstmpc'

rnn_cell = rnn.HybridSequentialRNNCell()
Expand Down