nn.GRU skips bhn at the first timestep when hidden=None #3249

@ssmall256

Description

Summary

With the default hidden=None, mlx.nn.GRU does not behave as if the initial hidden state were explicitly zero.

At the first timestep, the implementation skips the hidden-side contribution to the new gate entirely, including the learned bhn bias. That makes:

gru(x)

behave differently from:

gru(x, hidden=mx.zeros((batch_size, hidden_size)))

for the same weights and input.

Current status

I first noticed this on MLX 0.31.1 and confirmed it is still present on current main at:

  • b0564a91123ab79bd3bdbe1a251d3492a316349c

This appears to be present since v0.7.0, the first release that included nn.GRU.

Why this looks wrong

The GRU update in the docstring is:

n_t = tanh(W_xn x_t + b_n + r_t ⊙ (W_hn h_t + b_hn))

If hidden=None means an implicit zero initial state, then at t=0 this should reduce to:

n = tanh(W_xn x + b_n + r * b_hn)

But the current code skips the hidden-path work when hidden is None, so the first step becomes:

n = tanh(W_xn x + b_n)

which drops the r * b_hn term.
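To make the discrepancy concrete, here is a minimal NumPy sketch of the new-gate computation at t=0, done both ways. The weight names follow the docstring equations and are standalone stand-ins for illustration, not MLX's actual (fused) parameter layout:

```python
import numpy as np

rng = np.random.default_rng(0)
D, H = 4, 3

# Illustrative standalone weights, named after the docstring equation.
Wxn = rng.standard_normal((H, D)) * 0.1
Whn = rng.standard_normal((H, H)) * 0.1
bn = rng.standard_normal(H) * 0.1
bhn = rng.standard_normal(H) * 0.1
x = rng.standard_normal(D) * 0.1
r = 1.0 / (1.0 + np.exp(-rng.standard_normal(H)))  # stand-in reset gate

# Correct t=0 step with an implicit zero hidden state: Whn @ 0 vanishes,
# but the learned bias bhn survives inside the reset-gated term.
n_zero_hidden = np.tanh(Wxn @ x + bn + r * (Whn @ np.zeros(H) + bhn))
#             = np.tanh(Wxn @ x + bn + r * bhn)

# What the hidden=None path currently computes: the entire hidden-side
# contribution, including bhn, is skipped.
n_skipped = np.tanh(Wxn @ x + bn)

print("max diff:", np.abs(n_zero_hidden - n_skipped).max())  # nonzero whenever bhn != 0
```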

Repro

import numpy as np
import mlx.core as mx
import mlx.nn as nn

np.random.seed(42)
mx.random.seed(42)

D, H = 4, 3
gru = nn.GRU(D, H, bias=True)
x = mx.array(np.random.randn(1, 3, D).astype(np.float32) * 0.1)

y_none = gru(x)
y_zeros = gru(x, hidden=mx.zeros((1, H)))

mx.eval(y_none, y_zeros)

print("hidden=None: ", np.array(y_none)[0, 0])
print("hidden=zeros:", np.array(y_zeros)[0, 0])
print("max diff:    ", np.abs(np.array(y_none) - np.array(y_zeros)).max())

Observed on current main:

hidden=None:  [ 0.19130191  0.03415391 -0.12855242]
hidden=zeros: [ 0.12526409  0.06978775 -0.23501271]
max diff:     0.10646029

With bias=False, the mismatch disappears, since b_hn is the only hidden-side term that a zero initial state leaves behind at t=0.

PyTorch comparison

Using the same weights and input, PyTorch matches the explicit-zero MLX path, not the hidden=None MLX path.

Observed max diffs on the validated checkout:

max_diff(PyTorch, MLX hidden=zeros) = 5.96e-08
max_diff(PyTorch, MLX hidden=None)  = 1.0646035e-01

Impact

This affects nn.GRU when:

  • bias=True
  • hidden is omitted or passed as None

That is likely the common call pattern. The first hidden state is wrong, and the error then propagates through the rest of the sequence.

Test gap

The current GRU coverage in python/tests/test_nn.py appears to be shape-only and does not check that gru(x) matches gru(x, hidden=zeros).
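A regression test could assert exactly that invariant: hidden=None must match an explicit zero state. A sketch of the check against a NumPy reference implementation (per-gate weight names are hypothetical; MLX's real parameters are fused, not per-gate):

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x, h, p):
    # One GRU step following the docstring equations; p holds
    # illustrative per-gate weights.
    r = sigmoid(p["Wxr"] @ x + p["bxr"] + p["Whr"] @ h + p["bhr"])
    z = sigmoid(p["Wxz"] @ x + p["bxz"] + p["Whz"] @ h + p["bhz"])
    n = np.tanh(p["Wxn"] @ x + p["bn"] + r * (p["Whn"] @ h + p["bhn"]))
    return (1.0 - z) * n + z * h

def ref_gru(xs, p, hidden=None):
    # hidden=None is defined as an explicit zero state, so both call
    # forms must produce identical outputs.
    h = np.zeros(p["Whn"].shape[0]) if hidden is None else hidden
    out = []
    for x in xs:
        h = gru_step(x, h, p)
        out.append(h)
    return np.stack(out)

rng = np.random.default_rng(0)
D, H, T = 4, 3, 5
p = {k: rng.standard_normal((H, D)) * 0.1 for k in ("Wxr", "Wxz", "Wxn")}
p |= {k: rng.standard_normal((H, H)) * 0.1 for k in ("Whr", "Whz", "Whn")}
p |= {k: rng.standard_normal(H) * 0.1
      for k in ("bxr", "bhr", "bxz", "bhz", "bn", "bhn")}
xs = rng.standard_normal((T, D)) * 0.1

# The invariant the missing test should check:
assert np.allclose(ref_gru(xs, p), ref_gru(xs, p, hidden=np.zeros(H)))
```

The same assertion against nn.GRU itself (comparing gru(x) with gru(x, hidden=mx.zeros(...)) for bias=True) would have caught this bug.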
