Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FlaxGPTJ #14396

Merged
merged 9 commits into from
Dec 1, 2021
Merged

FlaxGPTJ #14396

merged 9 commits into from
Dec 1, 2021

Conversation

patil-suraj
Copy link
Contributor

What does this PR do?

This PR adds the GPTJ model in flax.

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me. At some point it will be nice to check the GPU JAX tests, I believe they're timing out right now.

# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this test - could you also add it in test_modeling_gptj.py so that the test fetcher always correctly checks PT<->Flax (e.g. at the moment if PT GPTJ is changed the test fetcher will not run this test, thus we need an identical test in the PyTorch test suite)

Copy link
Contributor Author

@patil-suraj patil-suraj Nov 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are already added in test_modeling_common.py, so they run for all pt models now

def test_equivalence_flax_to_pt(self):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok! So for the Flax files the PT<=>Flax tests need to be overwritten whereas for the PT files the test doesn't have to be overwritten? If this is the case - the PR is good to be merged for me :-)

@patrickvonplaten
Copy link
Contributor

Looks great - just one small test for PT<>Flax compatibility should be added in test_modeling_gptj as well :-)

@patil-suraj patil-suraj merged commit 4c0dd19 into huggingface:master Dec 1, 2021
@patil-suraj patil-suraj deleted the flax-gptj branch December 1, 2021 05:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants