diff --git a/python/mlx/nn/layers/recurrent.py b/python/mlx/nn/layers/recurrent.py
index 3ffa7654c3..a5d31fd8c2 100644
--- a/python/mlx/nn/layers/recurrent.py
+++ b/python/mlx/nn/layers/recurrent.py
@@ -184,6 +184,8 @@ def __call__(self, x, hidden=None):
 
             if hidden is not None:
                 n = n + r * h_proj_n
+            elif "bhn" in self:
+                n = n + r * self.bhn
             n = mx.tanh(n)
 
             if hidden is not None:
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 174823f179..5957184b19 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -1941,6 +1941,14 @@ def test_gru(self):
             h_out = layer(inp, h_out[-1, :])
         self.assertEqual(h_out.shape, (44, 12))
 
+        # hidden=None should be equivalent to hidden=zeros (issue #3249)
+        for bias in [True, False]:
+            layer = nn.GRU(5, 12, bias=bias)
+            inp = mx.random.normal((2, 25, 5))
+            h_none = layer(inp)
+            h_zeros = layer(inp, hidden=mx.zeros((2, 12)))
+            self.assertTrue(mx.allclose(h_none, h_zeros).item())
+
     def test_lstm(self):
         layer = nn.LSTM(5, 12)
         inp = mx.random.normal((2, 25, 5))