
Conversation

cosmo3769
Contributor

Hi @mattdangerw @ariG23498,

Ported the BART transformers checkpoint to KerasNLP. Please check. Thank you!

Member

@mattdangerw mattdangerw left a comment

Looks good! Just minor comments. Unlike Mistral, I think we can land with the test against the BART base version (though a tiny checkpoint for tests would still be great).

Needs merge conflict resolution because of albert I think.

class TestTask(TestCase):
    @pytest.mark.large
    def test_convert_tiny_preset(self):
        model = BartSeq2SeqLM.from_preset("hf://facebook/bart-base")
Member

less urgent than the 7b mistral model, but we might want to consider making a tiny test checkpoint for this model too
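
A tiny test checkpoint can be produced by instantiating a randomly initialized HF BART with very small dimensions and pushing it to the Hub. A minimal sketch, assuming the sizes below (d_model=32 and max_position_embeddings=1024 are inferred from the shapes discussed later in the thread; the repo name mirrors the tiny-bart-test preset used in the updated test):

from transformers import AutoTokenizer, BartConfig, BartForConditionalGeneration

# Randomly initialized BART with tiny dimensions so the conversion test
# downloads and runs in seconds.
config = BartConfig(
    vocab_size=50265,  # matches the facebook/bart-base tokenizer
    d_model=32,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=64,
    decoder_ffn_dim=64,
    max_position_embeddings=1024,
)
model = BartForConditionalGeneration(config)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

# Push both to the Hub so the test can load them via the hf:// preset URI.
model.push_to_hub("tiny-bart-test")
tokenizer.push_to_hub("tiny-bart-test")

The large-marked test can then be exercised locally with pytest's --run_large option (the flag name assumes the repository's standard conftest setup).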

Contributor Author

Added tiny-bart-test

@cosmo3769 cosmo3769 requested a review from mattdangerw August 21, 2024 17:26
Member

@mattdangerw mattdangerw left a comment

awesome work! thank you!

@mattdangerw
Member

(i'll merge as soon as testing finishes if all is green)

@mattdangerw mattdangerw merged commit 5299cd4 into keras-team:master Aug 21, 2024
6 checks passed
@mattdangerw
Member

Oops, I forgot to run our large testing for this. This is causing test failures, see below....

This is because we are not setting some weight (with shape [1024,32]) in the model from the safetensors checkpoint. So we probably "missed a spot" in the weight port.

@cosmo3769 can you take a look?

=================================== FAILURES ===================================
______________________ TestTask.test_convert_tiny_preset _______________________
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

self = 

    @pytest.mark.large
    def test_convert_tiny_preset(self):
        model = BartSeq2SeqLM.from_preset("hf://cosmo3769/tiny-bart-test")
        prompt = "What is your favorite condiment?"
>       model.generate([prompt], max_length=15)

keras_nlp/src/utils/transformers/convert_bart_test.py:25: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
keras_nlp/src/models/causal_lm.py:360: in generate
    outputs = [generate(x) for x in inputs]
keras_nlp/src/models/causal_lm.py:360: in 
    outputs = [generate(x) for x in inputs]
keras_nlp/src/models/causal_lm.py:350: in generate
    return generate_function(x, stop_token_ids=stop_token_ids)
keras_nlp/src/models/causal_lm.py:209: in wrapped_generate_function
    outputs, sampler_variables = compiled_generate_function(
/tmpfs/venv/lib/python3.9/site-packages/jax/_src/array.py:855: in _array_shard_arg
    x._check_if_deleted()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <[RuntimeError('Array has been deleted with shape=float32[1024,32].') raised in repr()] ArrayImpl object at 0x1756c8510>

    def _check_if_deleted(self):
      if self.is_deleted():
>       raise RuntimeError(
            f"Array has been deleted with shape={self.aval.str_short()}.")
E       RuntimeError: Array has been deleted with shape=float32[1024,32].

/tmpfs/venv/lib/python3.9/site-packages/jax/_src/array.py:553: RuntimeError
------------------------------ Captured log call -------------------------------
WARNING  tensorflow:polymorphic_function.py:157 5 out of the last 13 calls to > triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

---------- coverage: platform linux, python 3.9.19-final-0 -----------

=========================== short test summary info ============================
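
One way to narrow down which variable was never set (a hedged sketch; it assumes that reading the affected variable's value reproduces the same deleted-buffer error outside of generate()):

from keras_nlp.models import BartSeq2SeqLM

# Load the converted checkpoint (loading succeeds; only generate() fails),
# then try to read every backbone variable. A weight that was never set
# from the safetensors file should be the one whose backing buffer was
# deleted, and reading it surfaces the variable's path and shape.
model = BartSeq2SeqLM.from_preset("hf://cosmo3769/tiny-bart-test")
for variable in model.backbone.weights:
    try:
        variable.numpy()
    except RuntimeError as error:
        print(variable.path, variable.shape, error)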

@cosmo3769
Contributor Author

@mattdangerw, I looked into it. I found that:

keras: path=encoder_position_embedding/embeddings has shape (1024, 32).
hf: model.decoder.embed_positions.weight and model.encoder.embed_positions.weight have shape (1026, 32).

This causes a shape mismatch.

@cosmo3769
Contributor Author

Well, slicing off the first two elements works in this case:

port_weight(
    keras_variable=keras_backbone.encoder_position_embedding.position_embeddings,
    hf_weight_key="encoder.embed_positions.weight",
    hook_fn=lambda hf_tensor, keras_shape: np.reshape(
        hf_tensor[2:1026, :], keras_shape
    ),
)
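
(For context on the 1026 vs. 1024 mismatch: Hugging Face's BartLearnedPositionalEmbedding allocates max_position_embeddings + 2 rows because it adds a fixed offset of 2 to every position id, so the first two rows of the HF table are never indexed by KerasNLP's 0-based position embedding. The shapes can be confirmed directly from the checkpoint with a quick inspection sketch like the one below; it assumes the tiny checkpoint stores its weights in a single model.safetensors file.)

from huggingface_hub import hf_hub_download
from safetensors import safe_open

# Inspect the raw position-embedding shapes without instantiating a model.
path = hf_hub_download("cosmo3769/tiny-bart-test", "model.safetensors")
with safe_open(path, framework="np") as f:
    for key in f.keys():
        if "embed_positions" in key:
            print(key, f.get_tensor(key).shape)  # expected: (1026, 32)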

@mattdangerw
Member

@cosmo3769 thanks for the quick reply!

Are there two extra embedding positions being added in the HF implementation for position embeddings? Why? Might be worth looking into.

To unbreak CI, it sounds like we can just do this?

    hook_fn=lambda hf_tensor, keras_shape: np.reshape(hf_tensor[2:, :], keras_shape),

Doing a small PR to turn testing green again would be helpful as we dig deeper here. I wouldn't hardcode the "1024" part if we can avoid it.

@cosmo3769
Contributor Author

Yeah, I hardcoded the value above as a reference point. For now, let me raise a quick PR with the [2:, :] change to fix the CI issue. Then I can look into it more.
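
For reference, a sketch of what the follow-up change could look like, applying the dynamic slice to both position-embedding tables (it assumes the converter's port_weight helper shown above and that the decoder variable mirrors the encoder naming):

# Slice off the two offset rows of the HF table instead of hardcoding 1026,
# so the hook also works for presets with a different max sequence length.
for prefix in ("encoder", "decoder"):
    port_weight(
        keras_variable=getattr(
            keras_backbone, f"{prefix}_position_embedding"
        ).position_embeddings,
        hf_weight_key=f"{prefix}.embed_positions.weight",
        hook_fn=lambda hf_tensor, keras_shape: np.reshape(
            hf_tensor[2:, :], keras_shape
        ),
    )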
