
does Functional Model support "string" dtype for end-to-end pipelines? #18410

Open
Mrutyunjay01 opened this issue Sep 9, 2023 · 6 comments
Labels
backend:jax type:support User is asking for help / asking an implementation question. Stackoverflow would be better suited.

Comments

@Mrutyunjay01
Contributor

As mentioned in keras-team/keras-core#843, I was trying to further modify the BERT MLM pipeline for keras-core to make it backend-agnostic. The last part of the example builds an end-to-end pipeline that takes raw text as the model input, as follows:

def get_end_to_end(model):
    inputs_string = keras.Input(shape=(1,), dtype="string")
    indices = vectorize_layer(inputs_string)
    outputs = model(indices)
    end_to_end_model = keras.Model(inputs_string, outputs, name="end_to_end_model")
    optimizer = keras.optimizers.Adam(learning_rate=config.LR)
    end_to_end_model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model


end_to_end_classification_model = get_end_to_end(classifer_model)
end_to_end_classification_model.evaluate(test_raw_classifier_ds)

However, when running it with the JAX backend, it throws the following error:

ValueError                                Traceback (most recent call last)
Cell In[10], line 14
     10     return end_to_end_model
     13 end_to_end_classification_model = get_end_to_end(classifer_model)
---> 14 end_to_end_classification_model.evaluate(test_raw_classifier_ds)

File ~/Documents/oss/tfjax/keras-core/kscore/lib/python3.10/site-packages/keras_core/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    120     filtered_tb = _process_traceback_frames(e.__traceback__)
    121     # To get the full stack trace, call:
    122     # `keras_core.config.disable_traceback_filtering()`
--> 123     raise e.with_traceback(filtered_tb) from None
    124 finally:
    125     del filtered_tb

File ~/Documents/oss/tfjax/keras-core/kscore/lib/python3.10/site-packages/tree/__init__.py:435, in map_structure(func, *structures, **kwargs)
    432 for other in structures[1:]:
    433   assert_same_structure(structures[0], other, check_types=check_types)
    434 return unflatten_as(structures[0],
--> 435                     [func(*args) for args in zip(*map(flatten, structures))])

File ~/Documents/oss/tfjax/keras-core/kscore/lib/python3.10/site-packages/tree/__init__.py:435, in <listcomp>(.0)
    432 for other in structures[1:]:
    433   assert_same_structure(structures[0], other, check_types=check_types)
    434 return unflatten_as(structures[0],
--> 435                     [func(*args) for args in zip(*map(flatten, structures))])

ValueError: Invalid dtype: object

Could the input layer having dtype `string` be causing the problem here?
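The `Invalid dtype: object` message is consistent with how string data presumably looks by the time it leaves TensorFlow: a batch of `tf.string` values converts to a NumPy array of `bytes` objects with dtype `object`, which is not a dtype a JAX array can hold. The minimal sketch below, assuming only NumPy, mimics such a batch; `is_numeric_dtype` is a hypothetical helper for illustration, not a Keras function.

```python
import numpy as np

# Mimic a batch of raw text after a `tf.string` tensor is converted to NumPy:
# each element is a Python `bytes` object, so the array dtype is `object`.
raw_batch = np.array([b"this movie was great", b"not my cup of tea"], dtype=object)
print(raw_batch.dtype)  # object


def is_numeric_dtype(dtype):
    """Hypothetical helper: True for bool, integer, float and complex dtypes."""
    # NumPy kind codes: b=bool, i=signed int, u=unsigned int, f=float, c=complex.
    # Object ('O'), bytes ('S') and unicode ('U') arrays are rejected.
    return np.dtype(dtype).kind in "biufc"


print(is_numeric_dtype(raw_batch.dtype))  # False
print(is_numeric_dtype("int32"))  # True
```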

@fchollet
Member

The `string` dtype is only supported with the TF backend. There's a second issue here as well: the error message is unclear. It should simply say that the `string` dtype isn't supported with the JAX backend.
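A clearer failure, as suggested above, could come from checking the input dtype against the active backend before any conversion happens. The sketch below is hypothetical (`check_input_dtype` is not part of Keras) and only illustrates the kind of message being asked for; it assumes NumPy and treats object, bytes, and unicode arrays as string data.

```python
import numpy as np


def check_input_dtype(value, backend):
    """Hypothetical pre-conversion check with an explicit error message.

    Returns `value` as a NumPy array when its dtype can be used by `backend`.
    """
    arr = np.asarray(value)
    # NumPy kinds 'O' (object, e.g. bytes from tf.string), 'S' (bytes) and
    # 'U' (unicode) all represent string data.
    if arr.dtype.kind in "OSU" and backend != "tensorflow":
        raise ValueError(
            "Inputs with dtype 'string' are only supported with the TensorFlow "
            f"backend, but the current backend is {backend!r}. Convert text to "
            "integer token indices before passing it to the model."
        )
    return arr


# Integer token indices pass through unchanged on any backend.
ids = check_input_dtype([[2, 3, 4]], backend="jax")

# Raw text on the JAX backend fails with a message that names the real problem.
try:
    check_input_dtype([b"this movie was great"], backend="jax")
except ValueError as err:
    print(err)
```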

@SuryanarayanaY added the type:support label on Sep 11, 2023
@fchollet transferred this issue from keras-team/keras-core on Sep 22, 2023
@sachinprasadhs self-assigned this on May 1, 2024
@sachinprasadhs
Collaborator

@Mrutyunjay01, could you please provide a reproducible code sample? The tutorial you mentioned has other errors that need to be solved before we get to the reported issue.

@Mrutyunjay01
Contributor Author

Apologies for the late response. I ran into the issue while trying to port BERT MLM to be backend-agnostic. It's been a while, so let me refine the port code and check whether the issue still persists.

cc: @sachinprasadhs

@Mrutyunjay01
Contributor Author

Update:

The issue no longer occurs in Keras 3.3, so we can close this. I am currently porting BERT MLM to be backend-agnostic and will report back if I come across any similar issues.


@Mrutyunjay01
Contributor Author

Reopening the issue, since I ran into it again in the draft PR mentioned.
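Until string inputs work outside the TF backend, one backend-agnostic workaround is to keep text vectorization outside the model and feed the integer-only classifier token indices. With Keras itself, that would likely mean applying `vectorize_layer` in the `tf.data` pipeline (for example `test_raw_classifier_ds.map(lambda x, y: (vectorize_layer(x), y))`) instead of inside `keras.Model`, which is how the Keras 3 TextVectorization docs suggest using the layer with non-TF backends. The self-contained sketch below shows the same idea in plain NumPy. The toy `VOCAB` and `vectorize` function are hypothetical stand-ins for a fitted TextVectorization layer: like that layer, they reserve index 0 for padding and index 1 for `[UNK]`, but they only lowercase and split on whitespace.

```python
import numpy as np

# Hypothetical toy vocabulary standing in for a fitted TextVectorization layer:
# index 0 is padding and index 1 is the out-of-vocabulary token.
VOCAB = {"": 0, "[UNK]": 1, "this": 2, "movie": 3, "was": 4, "great": 5}


def vectorize(texts, vocab=VOCAB, seq_len=6):
    """Map raw strings to a padded int32 array of token indices, outside the model."""
    ids = np.zeros((len(texts), seq_len), dtype="int32")
    for i, text in enumerate(texts):
        if isinstance(text, bytes):  # tf.string batches arrive as bytes
            text = text.decode("utf-8")
        # Lowercase, split on whitespace and truncate to `seq_len` tokens.
        for j, token in enumerate(text.lower().split()[:seq_len]):
            ids[i, j] = vocab.get(token, vocab["[UNK]"])
    return ids


batch = vectorize([b"This movie was great", "this was terrible"])
print(batch.dtype, batch.shape)  # int32 (2, 6)
# `batch` can now be passed to the integer-only classifier
# (`classifer_model` in the example) on any backend.
```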
