I am getting the following error when fine-tuning the longT5 model:
`
ValueError Traceback (most recent call last)
Input In [16], in <cell line: 21>()
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
19 train_using_gin()
---> 21 gin_utils.run(main_train)
File ~/Downloads/t5x/t5x/gin_utils.py:105, in run(main)
103 def run(main):
104 """Wrapper for app.run that rewrites gin args before parsing."""
--> 105 app.run(
106 main,
107 flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))
File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser)
310 callback()
311 try:
--> 312 _run_main(main, args)
313 except UsageError as error:
314 usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)
File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv)
256 sys.exit(retval)
257 else:
--> 258 sys.exit(main(argv))
Input In [15], in main_train(argv)
1 def main_train(argv: Sequence[str]):
2 """Wrapper for pdb post mortems."""
----> 3 _main(argv)
Input In [16], in _main(argv)
12 train_using_gin = gin.configurable(train)
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
---> 19 train_using_gin()
File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1603 scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
1604 err_str = err_str.format(name, fn_or_cls, scope_info)
-> 1605 utils.augment_exception_message_and_reraise(e, err_str)
File ~/opt/miniconda3/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message)
39 proxy = ExceptionProxy()
40 ExceptionProxy.__qualname__ = type(exception).__qualname__
---> 41 raise proxy.with_traceback(exception.__traceback__) from None
File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1579 new_kwargs.update(kwargs)
1581 try:
-> 1582 return fn(*new_args, **new_kwargs)
1583 except Exception as e: # pylint: disable=broad-except
1584 err_str = ''
Input In [7], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period, random_seed, use_hardware_rng, summarize_config_fn, inference_evaluator_cls, get_dataset_fn, concurrent_metrics, actions, train_eval_get_dataset_fn, run_eval_before_training, use_gda)
224 input_types = {
225 k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
226 }
227 init_or_restore_tick = time.time()
--> 228 train_state_initializer = utils.TrainStateInitializer(
229 optimizer_def=model.optimizer_def,
230 init_fn=model.get_initial_variables,
231 input_shapes=input_shapes,
232 input_types=input_types,
233 partitioner=partitioner)
234 # 3. From scratch using init_fn.
235 train_state = train_state_initializer.from_checkpoint_or_scratch(
236 restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter)
File ~/Downloads/t5x/t5x/utils.py:368, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
365 self._partitioner = partitioner
366 self.global_train_state_shape = jax.eval_shape(
367 initialize_train_state, rng=jax.random.PRNGKey(0))
--> 368 self.train_state_axes = partitioner.get_mesh_axes(
369 self.global_train_state_shape)
370 self._initialize_train_state = initialize_train_state
372 # Currently scanned layers require passing annotations through to the
373 # point of the scan transformation to resolve an XLA SPMD issue.
374
375 # init_fn is always(?) equal to model.get_initial_variables, fetch the model
376 # instance from the bound method.
File ~/Downloads/t5x/t5x/partitioning.py:892, in PjitPartitioner.get_mesh_axes(self, train_state)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
--> 892 flat_mesh_axes = {
893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/Downloads/t5x/t5x/partitioning.py:893, in <dictcomp>(.0)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
892 flat_mesh_axes = {
--> 893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/Downloads/t5x/t5x/partitioning.py:888, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes)
885 return flax_partitioning.logical_to_mesh_axes(logical_axes,
886 self._logical_axis_rules)
887 except ValueError as e:
--> 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
In call to configurable 'train' (<function train at 0x2b751e160>)
`
I am getting the following error when fine-tuning the longT5 model:
`
ValueError Traceback (most recent call last)
Input In [16], in <cell line: 21>()
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
19 train_using_gin()
---> 21 gin_utils.run(main_train)
File ~/Downloads/t5x/t5x/gin_utils.py:105, in run(main)
103 def run(main):
104 """Wrapper for app.run that rewrites gin args before parsing."""
--> 105 app.run(
106 main,
107 flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))
File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser)
310 callback()
311 try:
--> 312 _run_main(main, args)
313 except UsageError as error:
314 usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)
File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv)
256 sys.exit(retval)
257 else:
--> 258 sys.exit(main(argv))
Input In [15], in main_train(argv)
1 def main_train(argv: Sequence[str]):
2 """Wrapper for pdb post mortems."""
----> 3 _main(argv)
Input In [16], in _main(argv)
12 train_using_gin = gin.configurable(train)
14 gin_utils.parse_gin_flags(
15 # User-provided gin paths take precedence if relative paths conflict.
16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
17 FLAGS.gin_file,
18 FLAGS.gin_bindings)
---> 19 train_using_gin()
File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1603 scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
1604 err_str = err_str.format(name, fn_or_cls, scope_info)
-> 1605 utils.augment_exception_message_and_reraise(e, err_str)
File ~/opt/miniconda3/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message)
39 proxy = ExceptionProxy()
40 ExceptionProxy.__qualname__ = type(exception).__qualname__
---> 41 raise proxy.with_traceback(exception.__traceback__) from None
File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
1579 new_kwargs.update(kwargs)
1581 try:
-> 1582 return fn(*new_args, **new_kwargs)
1583 except Exception as e: # pylint: disable=broad-except
1584 err_str = ''
Input In [7], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period, random_seed, use_hardware_rng, summarize_config_fn, inference_evaluator_cls, get_dataset_fn, concurrent_metrics, actions, train_eval_get_dataset_fn, run_eval_before_training, use_gda)
224 input_types = {
225 k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items()
226 }
227 init_or_restore_tick = time.time()
--> 228 train_state_initializer = utils.TrainStateInitializer(
229 optimizer_def=model.optimizer_def,
230 init_fn=model.get_initial_variables,
231 input_shapes=input_shapes,
232 input_types=input_types,
233 partitioner=partitioner)
234 # 3. From scratch using init_fn.
235 train_state = train_state_initializer.from_checkpoint_or_scratch(
236 restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter)
File ~/Downloads/t5x/t5x/utils.py:368, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
365 self._partitioner = partitioner
366 self.global_train_state_shape = jax.eval_shape(
367 initialize_train_state, rng=jax.random.PRNGKey(0))
--> 368 self.train_state_axes = partitioner.get_mesh_axes(
369 self.global_train_state_shape)
370 self._initialize_train_state = initialize_train_state
372 # Currently scanned layers require passing annotations through to the
373 # point of the scan transformation to resolve an XLA SPMD issue.
374
375 # init_fn is always(?) equal to model.get_initial_variables, fetch the model
376 # instance from the bound method.
File ~/Downloads/t5x/t5x/partitioning.py:892, in PjitPartitioner.get_mesh_axes(self, train_state)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
--> 892 flat_mesh_axes = {
893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/Downloads/t5x/t5x/partitioning.py:893, in <dictcomp>(.0)
888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
890 flat_logical_axes = traverse_util.flatten_dict(
891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
892 flat_mesh_axes = {
--> 893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
894 }
896 return logical_axes.restore_state(
897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
File ~/Downloads/t5x/t5x/partitioning.py:888, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes)
885 return flax_partitioning.logical_to_mesh_axes(logical_axes,
886 self._logical_axis_rules)
887 except ValueError as e:
--> 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e
ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
In call to configurable 'train' (<function train at 0x2b751e160>)
`