
Conversation


@mattdangerw mattdangerw commented Aug 15, 2024

Old behavior:

  • On the TF backend, ragged and string outputs are returned as tf tensors.
  • On the JAX/Torch backends, ragged and string outputs are returned as lists.
  • Preprocessing functions outside of __call__, like tokenize(), detokenize(), and generate_preprocess(), always return tf tensors on all backends.

This made it hard to write backend-agnostic code. TF shows up in random places, and if you are flipping from tf -> jax or vice versa, you have to switch between handling tensors and lists.

New behavior:

  • On all backends, for all preprocessing functions, ragged and string outputs are returned as lists.
  • Inside a tf.data pipeline or a tf-compiled function, preprocessing layers always output tf tensors.

This requires a little complexity to avoid repeatedly converting back and forth between tf and Python in nested calls, but thankfully we can hide most of that complexity in a decorator.
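The nested-call bookkeeping can be sketched in plain Python (no TensorFlow needed for the sketch; `preprocessing_function`, `_IN_TF_GRAPH`, and `to_list` are illustrative names, not the actual implementation in this PR):

```python
import functools

# Illustrative stand-in for "are we tracing inside tf.data / tf.function?";
# real code would consult TensorFlow (e.g. tf.executing_eagerly()).
_IN_TF_GRAPH = False


def preprocessing_function(fn):
    """Convert outputs to Python lists only at the outermost call.

    Inside a tf graph trace, or inside a nested preprocessing call,
    outputs pass through untouched, so values never bounce
    tf -> python -> tf repeatedly.
    """

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if _IN_TF_GRAPH or getattr(self, "_in_preprocessing", False):
            return fn(self, *args, **kwargs)  # nested/graph: no conversion
        self._in_preprocessing = True  # mark nested calls below this one
        try:
            outputs = fn(self, *args, **kwargs)
        finally:
            self._in_preprocessing = False
        return to_list(outputs)  # outermost eager call: convert once

    return wrapper


def to_list(x):
    # Stand-in for tensor -> list conversion (real code would call
    # something like RaggedTensor.to_list()).
    return x if isinstance(x, list) else list(x)


class ToyTokenizer:
    @preprocessing_function
    def tokenize(self, texts):
        # Pretend this returns a tensor-like object (here, a tuple).
        return tuple(t.lower().split() for t in texts)

    @preprocessing_function
    def generate_preprocess(self, texts):
        # Nested call: tokenize() skips conversion, so no tf -> python
        # round trip happens before the outermost wrapper converts once.
        return self.tokenize(texts)
```

The key design choice is the per-instance flag: only the outermost decorated call pays the conversion cost, while nested preprocessing calls see the raw (tensor-like) values.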

@mattdangerw mattdangerw force-pushed the consistent-preprocessing-outputs branch 7 times, most recently from 2008b84 to 5de16cd on August 16, 2024 03:04
@mattdangerw mattdangerw changed the title [DRAFT] Consistent preprocessing output on all backends Consistent preprocessing output on all backends Aug 16, 2024
@mattdangerw mattdangerw marked this pull request as ready for review August 16, 2024 03:05
@mattdangerw mattdangerw force-pushed the consistent-preprocessing-outputs branch from 5de16cd to 5fdaa4a on August 16, 2024 03:13
@mattdangerw mattdangerw force-pushed the consistent-preprocessing-outputs branch from c582212 to 331f6a1 on August 16, 2024 20:56

@SamanehSaadat SamanehSaadat left a comment


LGTM! Thanks, Matt! Just left a couple of nit comments!

@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label Aug 19, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Aug 19, 2024
@mattdangerw mattdangerw merged commit 180c7ec into keras-team:master Aug 19, 2024