In Trax, the inputs are just Python streams, not tf.data objects. (This is possible without a big speed penalty because JAX dispatches the training step to the GPU/TPU asynchronously from the Python interpreter, so the Python overhead is amortized by the training step.)
So to feed a tf.data dataset into Trax, you need to convert it back to Python. This is generally done with the function trax.math.dataset_as_numpy, but a few tweaks may be needed, such as making tuples out of dicts. In Trax, we do it here: https://github.com/google/trax/blob/master/trax/supervised/inputs.py#L366
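As a rough illustration of the conversion described above, here is a minimal sketch of a generator that turns dict-style examples into tuples of NumPy arrays. The helper name `to_numpy_stream` and the keys `"inputs"` and `"targets"` are hypothetical, chosen only for this example, and the list of dicts is a stand-in for a real dataset. With an actual tf.data dataset, one would presumably iterate over something like `dataset.as_numpy_iterator()` instead.

```python
import numpy as np


def to_numpy_stream(examples, input_key="inputs", target_key="targets"):
    """Yield (inputs, targets) tuples of NumPy arrays.

    `examples` is any iterable of dicts or tuples. The key names are
    assumptions for this sketch; use whatever your records contain.
    """
    for example in examples:
        if isinstance(example, dict):
            # Make a tuple out of the dict, as the reply suggests.
            example = (example[input_key], example[target_key])
        # np.asarray converts array-like values to plain NumPy arrays,
        # which JAX accepts as inputs.
        yield tuple(np.asarray(x) for x in example)


# Stand-in data so the sketch runs without TensorFlow installed.
fake_batches = [
    {"inputs": [[2, 16, 9, 0]], "targets": [[1]]},
    {"inputs": [[2, 70, 21, 0]], "targets": [[0]]},
]

stream = to_numpy_stream(fake_batches)
first_inputs, first_targets = next(stream)
```

How the resulting generator is handed to the Trax inputs class depends on the Trax version, so that step is left out here.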
Description
Hello, I'm facing a problem while trying to work with the Trax Trainer class. I have loaded my dataset from a TFRecords file and created a Dataset instance using the Dataset API. When I try to feed this dataset to the Trax trainer, I get the error shown below. Could you please tell me how to accomplish this? I haven't found anything explaining how to use the Dataset API with the Trax library. Thanks!
Environment information
OS: Google Colab notebook
For bugs: reproduction and error logs
Steps to reproduce:
Pass the dataset iterator to the trax.Inputs class
Error logs:
TypeError: Argument '[[ 2 16 9 ... 0 0 0]
[ 2 16 9 ... 0 0 0]
[ 2 16 9 ... 0 0 0]
...
[ 2 70 21 ... 0 0 0]
[ 2 16 9 ... 0 0 0]
[ 2 47 14 ... 0 0 0]]' of type <class 'tensorflow.python.framework.ops.EagerTensor'> is not a valid JAX type