I'm attempting to train a Transformer model for machine translation with a shared vocabulary. As expected, the input and target sequences have different lengths. I was expecting Trax to detect this and pad the sequences accordingly, but I couldn't find examples or documentation for this exact problem. Any advice would be greatly appreciated.
# Steps to reproduce:
1. Download a parallel text corpus.
2. Create a vocabulary and tokenize the source and target text and save as TFRecords.
3. Run the following code to train a Transformer model (see the sketch below):
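(The snippet itself is missing from this copy of the issue. As a stand-in, here is a minimal sketch of what such a training script could look like. It is an assumption, not the original code: `make_inputs()`, `VOCAB_SIZE`, the step counts, and the optimizer/learning-rate/loss choices are placeholders, and the `Trainer` keyword names may differ between Trax releases.)

```python
# Hypothetical reconstruction -- the original training snippet was not included
# above. make_inputs(), VOCAB_SIZE and the keyword choices below are assumptions;
# Trainer arguments changed between Trax releases.
import functools
import os

import trax
from trax.supervised import trainer_lib


VOCAB_SIZE = 32000  # assumed size of the shared source/target vocabulary


def make_inputs():
  """Return a trax.supervised.inputs.Inputs object whose streams yield
  (source_ids, target_ids) pairs read from the TFRecords of step 2.
  Placeholder -- the real data pipeline is not shown in the issue."""
  raise NotImplementedError


def train_model(output_dir):
  trainer = trainer_lib.Trainer(
      # Shared vocabulary: only input_vocab_size is set, so the decoder
      # reuses the same embedding table as the encoder.
      model=functools.partial(trax.models.Transformer,
                              input_vocab_size=VOCAB_SIZE),
      loss_fn=trax.layers.CrossEntropyLoss(),  # name may differ by Trax version
      optimizer=trax.optimizers.Adam,
      lr_schedule=trax.lr.MultifactorSchedule,  # assumed default schedule
      inputs=make_inputs(),
      output_dir=output_dir)
  for _ in range(10):
    trainer.train_epoch(n_steps=100, n_eval_steps=5)


if __name__ == '__main__':
  train_model(os.path.expanduser('~/train_dir/'))
```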
# Error logs:
/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Traceback (most recent call last):
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 480, in _forward_abstract
input_signature, weight_signature, self.state, rng)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/math/jax.py", line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/api.py", line 2104, in eval_shape
*map(abstractify, args_flat))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 274, in abstract_eval_fun
instantiate=True)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 220, in forward_with_state
return self.forward(inputs, weights), state
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/attention.py", line 42, in ShiftRight
pad_widths[1] = (n_shifts, 0) # Padding on axis=1
IndexError: list assignment index out of range
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 310, in init
weights, state = self.new_weights_and_state(input_signature)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/combinators.py", line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 485, in _forward_abstract
trace)
trax.layers.base.LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
layer created in file [...]/trax/models/transformer.py, line 291
layer input shapes: ShapeDtype{shape:(1,), dtype:int32}
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 2104, in eval_shape
*map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
instantiate=True)
File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 42, in ShiftRight
pad_widths[1] = (n_shifts, 0) # Padding on axis=1
IndexError: list assignment index out of range
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 310, in init
weights, state = self.new_weights_and_state(input_signature)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/combinators.py", line 91, in new_weights_and_state
weights_or_empty, state = sublayer.init(inputs)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 321, in init
input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/transformer.py, line 301
layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})
File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
layer created in file [...]/trax/models/transformer.py, line 291
layer input shapes: ShapeDtype{shape:(1,), dtype:int32}
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 2104, in eval_shape
*map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
instantiate=True)
File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 42, in ShiftRight
pad_widths[1] = (n_shifts, 0) # Padding on axis=1
IndexError: list assignment index out of range
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "src/machine_translation/trax/main.py", line 97, in <module>
main()
File "src/machine_translation/trax/main.py", line 93, in main
output_dir=os.path.expanduser('~/train_dir/'))
File "src/machine_translation/trax/main.py", line 21, in train_model
output_dir=output_dir
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 217, in __init__
self.reset(output_dir)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 297, in reset
opt_state, model_state = self._new_opt_state_and_model_state()
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 170, in <lambda>
model_target_shape, self._inputs.target_dtype, init_rng))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/api.py", line 150, in f_jitted
name=flat_fun.__name__)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/core.py", line 895, in call_bind
outs = primitive.impl(f, *args, **params)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 457, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 220, in memoized_fun
ans = call(fun, *args)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 474, in _xla_callable
fun, pvals, instantiate=False, stage_out_calls=True, bottom=True)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 159, in new_opt_state_and_model_state
weights, state = m.init(input_signature)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 321, in init
input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/supervised/trainer_lib.py, line 157
layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})
File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state
weights_or_empty, state = sublayer.init(inputs)
LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/transformer.py, line 301
layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})
File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
layer created in file [...]/trax/models/transformer.py, line 291
layer input shapes: ShapeDtype{shape:(1,), dtype:int32}
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 2104, in eval_shape
*map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
instantiate=True)
File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 42, in ShiftRight
pad_widths[1] = (n_shifts, 0) # Padding on axis=1
IndexError: list assignment index out of range
This is indeed a strange bug - we use different input and target lengths ourselves, and that shouldn't be a problem. This error may indicate something different though: are you sure your inputs and targets are never of length 0? I think an error like this could happen when the length is 0... can that be? (Could you check for it when producing them in the generator?)
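If it helps to check that, a small filter in front of the existing generator is enough to drop empty pairs before they reach the `Trainer`. This is just a sketch in plain Python; `raw_stream` stands in for whatever generator currently feeds the model and is not a Trax API.

```python
import numpy as np


def filtered_stream(raw_stream):
  """Yield only (source, target) pairs where both sides are non-empty.

  raw_stream is a placeholder for the existing generator that reads the
  tokenized TFRecords; it is not part of the Trax API.
  """
  for source_ids, target_ids in raw_stream:
    if len(source_ids) == 0 or len(target_ids) == 0:
      continue  # drop the zero-length examples suspected above
    yield (np.asarray(source_ids, dtype=np.int64),
           np.asarray(target_ids, dtype=np.int64))
```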