diff --git a/ivy/functional/backends/torch/layers.py b/ivy/functional/backends/torch/layers.py
index ebf94504dd536..e628b11acd34c 100644
--- a/ivy/functional/backends/torch/layers.py
+++ b/ivy/functional/backends/torch/layers.py
@@ -883,8 +883,6 @@ def lstm(
     weights_transposed: bool = False,
     has_ih_bias: bool = True,
     has_hh_bias: bool = True,
-    return_sequences: bool = True,
-    return_states: bool = True,
 ):
     if weights_transposed:
         # transpose the weights if they are in the wrong format
@@ -924,15 +922,4 @@ def lstm(
         batch_first,
     )
 
-    if return_states:
-        if return_sequences:
-            return ret
-        else:
-            return tuple(
-                [ret[0][:, -1], ret[1], ret[2]]
-            )  # TODO: this depends on batch_first
-    else:
-        if return_sequences:
-            return ret[0]
-        else:
-            return ret[0][:, -1]
+    return ret[0][:, -1], ret[0], (ret[1], ret[2])
diff --git a/ivy/functional/frontends/torch/nn/functional/layer_functions.py b/ivy/functional/frontends/torch/nn/functional/layer_functions.py
index 7dfce9e67993f..0a28dd7bf91d0 100644
--- a/ivy/functional/frontends/torch/nn/functional/layer_functions.py
+++ b/ivy/functional/frontends/torch/nn/functional/layer_functions.py
@@ -26,7 +26,7 @@ def _lstm_full(
     bidirectional,
     batch_first,
 ):
-    return ivy.lstm(
+    ret = ivy.lstm(
         input,
         hx,
         params,
@@ -38,6 +38,7 @@ def _lstm_full(
         has_ih_bias=has_biases,
         has_hh_bias=has_biases,
     )
+    return ret[1], ret[2][0], ret[2][1]
 
 
 def _lstm_packed(
@@ -51,7 +52,7 @@ def _lstm_packed(
     train,
     bidirectional,
 ):
-    return ivy.lstm(
+    ret = ivy.lstm(
         data,
         hx,
         params,
@@ -63,6 +64,7 @@ def _lstm_packed(
         has_ih_bias=has_biases,
         has_hh_bias=has_biases,
     )
+    return ret[1], ret[2][0], ret[2][1]
 
 
 # --- Main --- #
diff --git a/ivy/functional/ivy/layers.py b/ivy/functional/ivy/layers.py
index b9662a0cf1eda..ceabe86041d31 100644
--- a/ivy/functional/ivy/layers.py
+++ b/ivy/functional/ivy/layers.py
@@ -2397,8 +2397,6 @@ def lstm(
     weights_transposed: bool = False,
     has_ih_bias: bool = True,
     has_hh_bias: bool = True,
-    return_sequences: bool = True,
-    return_states: bool = True,
 ):
     """Applies a multi-layer long-short term memory to an input sequence.
 
@@ -2442,11 +2440,6 @@ def lstm(
         whether the `all_weights` argument includes a input-hidden bias
     has_hh_bias
         whether the `all_weights` argument includes a hidden-hidden bias
-    return_sequences
-        whether to return the last output in the output sequence,
-        or the full sequence
-    return_states
-        whether to return the final hidden and carry states in addition to the output
 
     Returns
     -------
@@ -2567,16 +2560,7 @@ def lstm(
     if batch_sizes is not None:
         output = _pack_padded_sequence(output, batch_sizes)[0]
 
-    if return_states:
-        if return_sequences:
-            return output, h_outs, c_outs
-        else:
-            return output[:, -1], h_outs, c_outs  # TODO: this depends on batch_first
-    else:
-        if return_sequences:
-            return output
-        else:
-            return output[:, -1]
+    return output[:, -1], output, (h_outs, c_outs)
 
 
 # Helpers #
diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
index f18a9b632b028..385b19b1a4c18 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
@@ -90,8 +90,6 @@ def _lstm_helper(draw):
     has_ih_bias = draw(st.booleans())
     has_hh_bias = draw(st.booleans())
     weights_transposed = draw(st.booleans())
-    return_sequences = draw(st.booleans())
-    return_states = draw(st.booleans())
     bidirectional = draw(st.booleans())
     dropout = draw(st.floats(min_value=0, max_value=0.99))
     train = draw(st.booleans()) and not dropout
@@ -217,8 +215,6 @@ def _lstm_helper(draw):
             "weights_transposed": weights_transposed,
             "has_ih_bias": has_ih_bias,
             "has_hh_bias": has_hh_bias,
-            "return_sequences": return_sequences,
-            "return_states": return_states,
         }
     else:
         dtypes = dtype
@@ -234,8 +230,6 @@ def _lstm_helper(draw):
             "weights_transposed": weights_transposed,
             "has_ih_bias": has_ih_bias,
             "has_hh_bias": has_hh_bias,
-            "return_sequences": return_sequences,
-            "return_states": return_states,
         }
 
     return dtypes, kwargs
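For reference, below is a minimal usage sketch (not part of the patch) of how the updated `ivy.lstm` return value could be consumed. It assumes the new return layout `(last_output, full_output, (h_outs, c_outs))` and the positional argument order visible in the `_lstm_full` hunk above; the backend choice, shapes, and variable names are illustrative placeholders rather than anything prescribed by the diff.

# Hedged sketch, not part of the diff: exercises the new return convention.
import ivy

ivy.set_backend("torch")  # assumes the torch backend is installed

batch, seq_len, input_size, hidden = 2, 5, 3, 4
x = ivy.random_normal(shape=(batch, seq_len, input_size))
h0 = ivy.zeros((1, batch, hidden))  # (num_layers * num_directions, batch, hidden)
c0 = ivy.zeros((1, batch, hidden))
weights = [
    ivy.random_normal(shape=(4 * hidden, input_size)),  # w_ih, torch layout
    ivy.random_normal(shape=(4 * hidden, hidden)),  # w_hh
    ivy.zeros((4 * hidden,)),  # b_ih
    ivy.zeros((4 * hidden,)),  # b_hh
]

# New return convention from this patch: (last step, full sequence, (h_n, c_n)).
last_out, seq_out, (h_n, c_n) = ivy.lstm(
    x,
    (h0, c0),
    weights,
    1,  # num_layers
    0.0,  # dropout
    False,  # train
    False,  # bidirectional
    batch_first=True,
)

# The torch frontend then re-packs this as (seq_out, h_n, c_n), matching the
# output order of torch's _lstm_full.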