
Fix GRU hidden=None handling for biased new-gate path #3250

Closed
ssmall256 wants to merge 1 commit into ml-explore:main from ssmall256:fix-gru-hidden-none-bias
Conversation

@ssmall256

Proposed changes

Fix nn.GRU so hidden=None is treated the same as an explicit zero initial state.

Previously, the hidden=None path skipped the hidden-side new-gate contribution at the first timestep, including bhn, so gru(x) could differ from gru(x, hidden=zeros) when bias=True.

This change initializes a zero hidden state before the loop and adds a regression test covering batched and unbatched inputs.
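To make the mismatch concrete, here is a minimal pure-Python sketch of a single scalar GRU step using the standard GRU equations. It is not the MLX source: the function name `gru_first_step`, the flag `skip_hidden_term_when_none`, and the parameter dictionary keys are all hypothetical, chosen only for illustration. The flag emulates the pre-fix behaviour described above, where the hidden-side new-gate term (including `bhn`) is dropped when `hidden` is `None`.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def gru_first_step(x, p, hidden=None, skip_hidden_term_when_none=False):
    """One scalar GRU step. `p` is a dict of hypothetical weights and biases."""
    # Treat a missing hidden state as zeros, as this PR proposes.
    h = 0.0 if hidden is None else hidden
    r = sigmoid(p["wxr"] * x + p["whr"] * h + p["br"])  # reset gate
    z = sigmoid(p["wxz"] * x + p["whz"] * h + p["bz"])  # update gate
    if hidden is None and skip_hidden_term_when_none:
        # Emulates the reported bug: r * (whn * h + bhn) is skipped entirely,
        # so bhn never reaches the new gate at the first timestep.
        n = math.tanh(p["wxn"] * x + p["bn"])
    else:
        n = math.tanh(p["wxn"] * x + p["bn"] + r * (p["whn"] * h + p["bhn"]))
    return (1.0 - z) * n + z * h


# Arbitrary illustrative parameters with a nonzero bhn.
params = {
    "wxr": 0.3, "whr": -0.4, "br": 0.1,
    "wxz": -0.2, "whz": 0.5, "bz": 0.05,
    "wxn": 0.8, "whn": 0.6, "bn": -0.1, "bhn": 0.7,
}

with_zeros = gru_first_step(0.5, params, hidden=0.0)
fixed_none = gru_first_step(0.5, params, hidden=None)
buggy_none = gru_first_step(0.5, params, hidden=None, skip_hidden_term_when_none=True)

print(fixed_none == with_zeros)  # the zero-initialized path matches explicit zeros
print(buggy_none == with_zeros)  # the bhn-skipping path does not
```

With `bhn` set to zero the two paths coincide, which is why the discrepancy only shows up when `bias=True`.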

Closes #3249

Tests

PYTHONPATH=python pytest python/tests/test_nn.py

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@ssmall256
Author

ssmall256 commented Mar 13, 2026

I checked prior MLX issues and PRs for duplicates before opening this. I did not find an existing report or fix for this specific hidden=None / first-step bhn behavior. The closest prior GRU-related change I found was #952, but that patch addressed a different GRU bug in the hidden-state update and did not change the hidden=None first-timestep path.

@angeloskath
Member

Hi @ssmall256, thanks for the fix! I hope it's OK, but I will merge #3252 instead because it is a bit more concise and efficient. The efficiency comes from avoiding all the computations with zeros when None is passed.

Thanks again for the fix!
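As a rough illustration of the efficiency point, the scalar sketch below compares a first step that multiplies through an explicit zero hidden state with one that drops every term known to be zero and keeps only `bhn`. This is an assumption about the general idea, not the actual contents of #3252; the function names and parameter keys are hypothetical.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def first_step_with_zeros(x, p):
    """First GRU step computed against an explicit zero hidden state."""
    h = 0.0
    r = sigmoid(p["wxr"] * x + p["whr"] * h + p["br"])
    z = sigmoid(p["wxz"] * x + p["whz"] * h + p["bz"])
    n = math.tanh(p["wxn"] * x + p["bn"] + r * (p["whn"] * h + p["bhn"]))
    return (1.0 - z) * n + z * h


def first_step_without_hidden(x, p):
    """Same result with the zero-valued terms dropped; only bhn survives."""
    r = sigmoid(p["wxr"] * x + p["br"])
    z = sigmoid(p["wxz"] * x + p["bz"])
    n = math.tanh(p["wxn"] * x + p["bn"] + r * p["bhn"])
    return (1.0 - z) * n


# Arbitrary illustrative parameters.
demo_params = {
    "wxr": 0.3, "whr": -0.4, "br": 0.1,
    "wxz": -0.2, "whz": 0.5, "bz": 0.05,
    "wxn": 0.8, "whn": 0.6, "bn": -0.1, "bhn": 0.7,
}

slow = first_step_with_zeros(0.5, demo_params)
fast = first_step_without_hidden(0.5, demo_params)
print(math.isclose(slow, fast))
```

In a real layer the skipped terms are matrix products against a zero tensor, so dropping them saves work without changing the result.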


Development

Successfully merging this pull request may close these issues.

nn.GRU skips bhn at the first timestep when hidden=None

2 participants