Skip to content

Port Jax AI stack minigpt example#5405

Merged
copybara-service[bot] merged 1 commit into
google:mainfrom
samanklesaria:minigpt
Apr 29, 2026
Merged

Port Jax AI stack minigpt example#5405
copybara-service[bot] merged 1 commit into
google:mainfrom
samanklesaria:minigpt

Conversation

@samanklesaria
Copy link
Copy Markdown
Collaborator

This PR copies the Jax AI stack example https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html to the flax examples directory unchanged. The nnx apis used still seem to be appropriate, so no adjustments are required.

@review-notebook-app
Copy link
Copy Markdown

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Comment thread docs_nnx/examples/minigpt.md Outdated
Comment thread docs_nnx/examples/minigpt.md Outdated
Comment thread docs_nnx/examples/minigpt.md Outdated
Comment thread docs_nnx/examples/minigpt.md Outdated
@samanklesaria samanklesaria marked this pull request as draft April 13, 2026 21:18
@samanklesaria
Copy link
Copy Markdown
Collaborator Author

Actually, it seems we need to do a little more to make this work with the current NNX. Currently getting some sharding errors. Will debug.

@samanklesaria samanklesaria marked this pull request as ready for review April 14, 2026 15:25
@samanklesaria
Copy link
Copy Markdown
Collaborator Author

Sharding issues resolved. Have to wait until I have access to multiple gpus to run the full notebook though.

@samanklesaria
Copy link
Copy Markdown
Collaborator Author

I also exposed the is_causal argument from the dot_product_attention function to MultiHeadAttention. It seems like the user is otherwise expected to generate the causal mask themselves.

@cgarciae
Copy link
Copy Markdown
Collaborator

Looks good but there are some merge issues.

@copybara-service copybara-service Bot merged commit 56b4f38 into google:main Apr 29, 2026
21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants