
IndexError: list assignment index out of range in Transformer model #403

Closed
nate-bush opened this issue Mar 27, 2020 · 3 comments

Comments

@nate-bush

Description

I'm attempting to train a Transformer model for machine translation with a shared vocabulary. As expected, the input and target sequences have different lengths, and I was expecting Trax to detect this and pad the sequences accordingly. I couldn't find examples or documentation for this exact setup. Any advice would be greatly appreciated.
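
For reference, this is roughly the padding/batching I expected to happen automatically. Below is a minimal sketch in plain NumPy, assuming pad id 0 and padding to the longest sequence in each batch (pad_batch and batched_stream are hypothetical helpers, not Trax API):

import numpy as np

def pad_batch(sequences, pad_id=0):
  """Pad a list of 1-D int arrays up to the longest length in the batch."""
  max_len = max(len(s) for s in sequences)
  batch = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
  for i, s in enumerate(sequences):
    batch[i, :len(s)] = s
  return batch

def batched_stream(example_iter, batch_size=8):
  """Group (inputs, targets) pairs into padded (batch, length) arrays."""
  buf = []
  for inputs, targets in example_iter:
    buf.append((inputs, targets))
    if len(buf) == batch_size:
      yield (pad_batch([x for x, _ in buf]),
             pad_batch([y for _, y in buf]))
      buf = []

The streams passed to trax.supervised.Inputs could then yield these padded 2-D batches instead of single 1-D examples.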

Environment information

OS: MacOS 10.14.6 (18G3020)

$ pip freeze | grep tensor
mesh-tensorflow==0.1.12
tensor2tensor==1.15.4
tensorboard==2.1.1
tensorflow==2.1.0
tensorflow-datasets==2.1.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0

$ pip freeze | grep jax
jax==0.1.62
jaxlib==0.1.42

$ python -V
Python 3.7.6

For bugs: reproduction and error logs

# Steps to reproduce:
1. Download a parallel text corpus.
2. Create a vocabulary, tokenize the source and target text, and save the results as TFRecords.
3. Run the following code to train a Transformer model:
import os

import tensorflow as tf
import trax

from src.common.params import Paths


def train_model(inputs: trax.supervised.Inputs, model_function, output_dir):

  trainer = trax.supervised.Trainer(
    model=model_function,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adafactor,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=inputs,
    output_dir=output_dir
  )

  n_epochs = 10
  train_steps = 10
  eval_steps = 1
  for _ in range(n_epochs):
    trainer.train_epoch(train_steps, eval_steps)


def parse_example(serialized_example):
  """Return inputs and targets Tensors from a serialized tf.Example."""
  data_fields = {
      "inputs": tf.io.VarLenFeature(tf.int64),
      "targets": tf.io.VarLenFeature(tf.int64)
  }
  parsed = tf.io.parse_single_example(serialized_example, data_fields)
  inputs = tf.sparse.to_dense(parsed["inputs"])
  targets = tf.sparse.to_dense(parsed["targets"])
  return inputs, targets


def file_length(filename):
  with open(filename) as f:
    for i, l in enumerate(f):
      pass
  return i + 1


def main():

  ML_ROOT = os.path.join(Paths.data_root, 'machine_translation')
  trax_path = os.path.join(ML_ROOT, 'trax')
  tokenizer_path = os.path.join(trax_path, 'subtoken.vocab')
  tokenized_records_path = os.path.join(trax_path, 'tokenized')
  os.makedirs(trax_path, exist_ok=True)
  os.makedirs(tokenized_records_path, exist_ok=True)

  tf_record_filenames = [os.path.join(tokenized_records_path, p) for p in
                         os.listdir(tokenized_records_path)]

  dataset = tf.data.TFRecordDataset(tf_record_filenames).map(parse_example)

  inputs = trax.supervised.Inputs(
    train_stream=lambda _: dataset.as_numpy_iterator(),
    eval_stream=lambda _: dataset.as_numpy_iterator()
  )

  # Peek into the inputs.
  data_stream = inputs.train_stream(n_devices=1)
  for _ in range(10):
    sample_input, sample_target = next(data_stream)
    print('-' * 100)
    print("Inputs:  %s, len: %s" % (str(sample_input), str(len(sample_input))))
    print("Targets: %s, len: %s" % (str(sample_target), str(len(sample_target))))

  vocab_size = file_length(tokenizer_path)
  print('Vocab size:', vocab_size)

  def transformer(mode):
    return trax.models.Transformer(
      vocab_size,
      mode=mode
    )

  train_model(inputs, model_function=transformer, output_dir=os.path.expanduser('~/train_dir/'))


if __name__ == '__main__':
  main()

Other output:

----------------------------------------------------------------------------------------------------
Inputs:  [704 656  32 769   2 588 820 936   2  47   4   1], len: 12
Targets: [946 947 950 942 937 462   2   7   5 238  21 377 336   4   1], len: 15
----------------------------------------------------------------------------------------------------
Inputs:  [798 128 221 866   2 249  37 471 912   4   1], len: 11
Targets: [946 947 950 944 937   2 338 188 190 383 301 106   4   1], len: 14
----------------------------------------------------------------------------------------------------
Inputs:  [ 55 641  34 425 685  53 426 391 356 426   4   1], len: 12
Targets: [946 947 948 953 937 216  10  46 230 104 196 172 364 187 187  21   4   1], len: 18
----------------------------------------------------------------------------------------------------
Inputs:  [167 475  45  20 122 139  48   2  56 148  76  56 246  33  53 424 299 209
 220 687   2   4   1], len: 23
Targets: [988 240  64 127 719  29  92 346 380 109 206 292 163 378   5  67   4   1], len: 18
----------------------------------------------------------------------------------------------------
Inputs:  [ 66 156  84 687  43 246  56 687   2  40 246  81 687 739 258  56 148  30
 687  50 156  83  54 246  56 687   2 697 156  30 687  45   4   1], len: 34
Targets: [946 947 950 944 937   2  71 698 980 932 827 827   2  13 322   4   1], len: 17
----------------------------------------------------------------------------------------------------
Inputs:  [129 665 256 145 470   4   1], len: 7
Targets: [946 947 950 944 937   2 802   5  12   4   1], len: 11
----------------------------------------------------------------------------------------------------
Inputs:  [225 255 861  85 148  45  28  56 148  50 687  95 246  62 148 165   2 431
  51 181 246 393 739  86  66 148 432  41   4   1], len: 30
Targets: [946 947 950 944 937   2 123 380 304  77  11 745  16  67 153 709 126 618
 462   2   4   1], len: 22
----------------------------------------------------------------------------------------------------
Inputs:  [438 118  95 128 903 162  69 282 239   4   1], len: 11
Targets: [946 947 950 944 937  10  18 311 394 404 311 778 119 238  64  27   4   1], len: 18
----------------------------------------------------------------------------------------------------
Inputs:  [101 602 195 310  37  19   1], len: 7
Targets: [946 947 942 951 937 261  12 231 261  17   1], len: 11
----------------------------------------------------------------------------------------------------
Inputs:  [162 253 481 128  78 141 161 145   4   1], len: 10
Targets: [946 947 950 944 937  10  91 213 357   5 106   6 410 189 613  10   4   1], len: 18
Vocab size: 992
# Error logs:
/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Traceback (most recent call last):
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 480, in _forward_abstract
    input_signature, weight_signature, self.state, rng)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/math/jax.py", line 175, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/api.py", line 2104, in eval_shape
    *map(abstractify, args_flat))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 274, in abstract_eval_fun
    instantiate=True)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 477, in call_on_input
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 220, in forward_with_state
    return self.forward(inputs, weights), state
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 580, in _forward
    raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/attention.py", line 42, in ShiftRight
    pad_widths[1] = (n_shifts, 0)  # Padding on axis=1
IndexError: list assignment index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 310, in init
    weights, state = self.new_weights_and_state(input_signature)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/combinators.py", line 92, in new_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 485, in _forward_abstract
    trace)
trax.layers.base.LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
  layer created in file [...]/trax/models/transformer.py, line 291
  layer input shapes: ShapeDtype{shape:(1,), dtype:int32}

  File [...]/trax/math/jax.py, line 175, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)

  File [...]/site-packages/jax/api.py, line 2104, in eval_shape
    *map(abstractify, args_flat))

  File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
    instantiate=True)

  File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/trax/layers/base.py, line 477, in call_on_input
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)

  File [...]/trax/layers/base.py, line 220, in forward_with_state
    return self.forward(inputs, weights), state

  File [...]/trax/layers/base.py, line 580, in _forward
    raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access

  File [...]/trax/layers/attention.py, line 42, in ShiftRight
    pad_widths[1] = (n_shifts, 0)  # Padding on axis=1

IndexError: list assignment index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 310, in init
    weights, state = self.new_weights_and_state(input_signature)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/combinators.py", line 91, in new_weights_and_state
    weights_or_empty, state = sublayer.init(inputs)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 321, in init
    input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/trax/models/transformer.py, line 301
  layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})

  File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
  layer created in file [...]/trax/models/transformer.py, line 291
  layer input shapes: ShapeDtype{shape:(1,), dtype:int32}

  File [...]/trax/math/jax.py, line 175, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)

  File [...]/site-packages/jax/api.py, line 2104, in eval_shape
    *map(abstractify, args_flat))

  File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
    instantiate=True)

  File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/trax/layers/base.py, line 477, in call_on_input
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)

  File [...]/trax/layers/base.py, line 220, in forward_with_state
    return self.forward(inputs, weights), state

  File [...]/trax/layers/base.py, line 580, in _forward
    raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access

  File [...]/trax/layers/attention.py, line 42, in ShiftRight
    pad_widths[1] = (n_shifts, 0)  # Padding on axis=1

IndexError: list assignment index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "src/machine_translation/trax/main.py", line 97, in <module>
    main()
  File "src/machine_translation/trax/main.py", line 93, in main
    output_dir=os.path.expanduser('~/train_dir/'))
  File "src/machine_translation/trax/main.py", line 21, in train_model
    output_dir=output_dir
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 217, in __init__
    self.reset(output_dir)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 297, in reset
    opt_state, model_state = self._new_opt_state_and_model_state()
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 170, in <lambda>
    model_target_shape, self._inputs.target_dtype, init_rng))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/api.py", line 150, in f_jitted
    name=flat_fun.__name__)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/core.py", line 895, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 457, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 220, in memoized_fun
    ans = call(fun, *args)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 474, in _xla_callable
    fun, pvals, instantiate=False, stage_out_calls=True, bottom=True)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 159, in new_opt_state_and_model_state
    weights, state = m.init(input_signature)
  File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 321, in init
    input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/trax/supervised/trainer_lib.py, line 157
  layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})

  File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state
    weights_or_empty, state = sublayer.init(inputs)

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/trax/models/transformer.py, line 301
  layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})

  File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
  layer created in file [...]/trax/models/transformer.py, line 291
  layer input shapes: ShapeDtype{shape:(1,), dtype:int32}

  File [...]/trax/math/jax.py, line 175, in shape_fun
    jax_shapes = jax.eval_shape(f, *args, **kwargs)

  File [...]/site-packages/jax/api.py, line 2104, in eval_shape
    *map(abstractify, args_flat))

  File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
    instantiate=True)

  File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/trax/layers/base.py, line 477, in call_on_input
    return self.forward_with_state(x, weights=weights, state=state, rng=rng)

  File [...]/trax/layers/base.py, line 220, in forward_with_state
    return self.forward(inputs, weights), state

  File [...]/trax/layers/base.py, line 580, in _forward
    raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access

  File [...]/trax/layers/attention.py, line 42, in ShiftRight
    pad_widths[1] = (n_shifts, 0)  # Padding on axis=1

IndexError: list assignment index out of range
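
The innermost frame suggests ShiftRight builds one (before, after) padding pair per input axis and then writes to index 1, which only exists when the input already has a batch dimension. A minimal sketch of that mechanism, reconstructed from the traceback (the pad_widths construction here is an assumption, not the actual Trax source):

# Input signature reported above: ShapeDtype{shape:(1,), dtype:int32}
x_shape = (1,)                        # rank-1 example, no batch axis
pad_widths = [(0, 0)] * len(x_shape)  # one (before, after) pair per axis -> [(0, 0)]
n_shifts = 1
pad_widths[1] = (n_shifts, 0)         # IndexError: list assignment index out of range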
@YannickWehr

I am having the same problem; however, I only encounter it in Google Colab, not when running identical code on my own machine.

@lukaszkaiser
Contributor

This is indeed a strange bug: we use different input and target lengths ourselves, and that shouldn't be a problem. The error may indicate something different, though: are you sure your inputs and targets are never of length 0? I think an error like this could happen if a sequence has length 0... could that be the case? (Could you check for that when producing them in the generator?)
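
In case it helps with that check, here is a minimal sketch of a guard around the example stream (filter_empty is a hypothetical helper, not part of Trax):

def filter_empty(example_iter):
  """Skip (inputs, targets) pairs where either sequence has length 0."""
  for inputs, targets in example_iter:
    if len(inputs) == 0 or len(targets) == 0:
      print('Skipping empty example:', inputs, targets)
      continue
    yield inputs, targets

# For example, wrap the stream before handing it to trax.supervised.Inputs:
# train_stream=lambda _: filter_empty(dataset.as_numpy_iterator())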

@lukaszkaiser
Contributor

Any updates on this? I'm tending towards closing it soon, as a lot has changed in recent versions, but let us know if the problem persists!
