Summary
mlx.nn.GRU is not equivalent to an explicit zero initial state when called with the default hidden=None.
At the first timestep, the implementation skips the hidden-side contribution to the new gate entirely, including the learned b_hn bias. That makes:
gru(x)
behave differently from:
gru(x, hidden=mx.zeros((batch_size, hidden_size)))
for the same weights and input.
Current status
I first noticed this on MLX 0.31.1 and confirmed it is still present on current main at:
b0564a91123ab79bd3bdbe1a251d3492a316349c
This appears to be present since v0.7.0, the first release that included nn.GRU.
Why this looks wrong
The GRU update in the docstring is:
n_t = tanh(W_xn x_t + b_n + r_t ⊙ (W_hn h_t + b_hn))
If hidden=None means an implicit zero initial state, then at t=0 this should reduce to:
n = tanh(W_xn x + b_n + r * b_hn)
But the current code skips the hidden-path work when hidden is None, so the first step becomes:
n = tanh(W_xn x + b_n)
which drops the r * b_hn term.
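The discrepancy is visible in the first-step math alone. This numpy sketch (hypothetical unpacked per-gate weight names for illustration, not the actual MLX parameter layout) compares the documented update against the version that skips the hidden path at t = 0:

```python
import numpy as np

rng = np.random.default_rng(0)
D, H = 4, 3  # input size, hidden size
x = rng.standard_normal(D)

# Hypothetical per-gate parameters for illustration only.
W_xr, W_hr = rng.standard_normal((H, D)), rng.standard_normal((H, H))
W_xn, W_hn = rng.standard_normal((H, D)), rng.standard_normal((H, H))
b_r = rng.standard_normal(H)
b_n = rng.standard_normal(H)
b_hn = rng.standard_normal(H)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

h0 = np.zeros(H)                         # implicit zero initial state
r = sigmoid(W_xr @ x + b_r + W_hr @ h0)  # reset gate at t = 0

# Documented update with h0 = 0: the hidden path still contributes r * b_hn.
n_correct = np.tanh(W_xn @ x + b_n + r * (W_hn @ h0 + b_hn))

# Hidden path skipped entirely, as when hidden is None in the buggy code path.
n_skipped = np.tanh(W_xn @ x + b_n)

print("first-step gap:", np.abs(n_correct - n_skipped).max())
```

The gap is exactly the effect of the r * b_hn term, which is also why the mismatch vanishes with bias=False.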
Repro
import numpy as np
import mlx.core as mx
import mlx.nn as nn

np.random.seed(42)
mx.random.seed(42)

D, H = 4, 3  # input size, hidden size
gru = nn.GRU(D, H, bias=True)
x = mx.array(np.random.randn(1, 3, D).astype(np.float32) * 0.1)  # (batch, seq, features)

y_none = gru(x)                            # default hidden=None
y_zeros = gru(x, hidden=mx.zeros((1, H)))  # explicit zero initial state
mx.eval(y_none, y_zeros)

print("hidden=None: ", np.array(y_none)[0, 0])
print("hidden=zeros:", np.array(y_zeros)[0, 0])
print("max diff: ", np.abs(np.array(y_none) - np.array(y_zeros)).max())
Observed on current main:
hidden=None: [ 0.19130191 0.03415391 -0.12855242]
hidden=zeros: [ 0.12526409 0.06978775 -0.23501271]
max diff: 0.10646029
With bias=False, the mismatch disappears.
PyTorch comparison
Using the same weights and input, PyTorch matches the explicit-zero MLX path, not the hidden=None MLX path.
Observed max diffs on the validated checkout:
max_diff(PyTorch, MLX hidden=zeros) = 5.96e-08
max_diff(PyTorch, MLX hidden=None) = 1.0646035e-01
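The PyTorch side of that comparison can be checked in isolation. This standalone sketch (not the original comparison script) confirms that torch.nn.GRU treats the default hidden state of None as an explicit zero initial state:

```python
import torch

torch.manual_seed(0)
D, H = 4, 3
gru = torch.nn.GRU(D, H, bias=True, batch_first=True)
x = torch.randn(1, 3, D) * 0.1

y_none, _ = gru(x)                         # default: hx=None
y_zeros, _ = gru(x, torch.zeros(1, 1, H))  # explicit zero initial state
print("max diff:", (y_none - y_zeros).abs().max().item())
```

Both calls take the same code path internally (PyTorch materializes zeros when hx is None), so the difference comes out to zero.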
Impact
This affects nn.GRU when:
bias=True
hidden is omitted or passed as None
That is likely the common call pattern. The first hidden state is wrong, and the error then propagates through the rest of the sequence.
Test gap
The current GRU coverage in python/tests/test_nn.py appears to be shape-only and does not check that gru(x) matches gru(x, hidden=zeros).