
52B example sharding error #41

Closed
deeplearningapps opened this issue May 24, 2023 · 3 comments

@deeplearningapps

Hi,

I was trying to run the 1x v4-384 52B model example following MaxText/configs/1xv4-384.sh on a v4-384 slice and hit the following error:

Traceback (most recent call last): 
  File "maxtext/MaxText/train.py", line 334, in <module> 
    app.run(main) 
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run 
    _run_main(main, args) 
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main 
    sys.exit(main(argv)) 
  File "maxtext/MaxText/train.py", line 330, in main 
    train_loop(pyconfig.config) 
  File "maxtext/MaxText/train.py", line 277, in train_loop 
    state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, init_rng, mesh, checkpoint_manager) 
  File "/home/wx/maxtext/MaxText/max_utils.py", line 159, in setup_initial_state 
    state = pjit( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback 
    return fun(*args, **kwargs) 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 208, in cache_miss 
    outs, out_flat, out_tree, args_flat = _python_pjit_helper( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 150, in _python_pjit_helper 
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 735, in infer_params 
    return common_infer_params(pjit_info_args, *args, **kwargs) 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 474, in common_infer_params 
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 937, in _pjit_jaxpr 
    canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 920, in _check_and_canonicalize_out_shardings 
    pjit_check_aval_sharding( 
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/pjit.py", line 973, in pjit_check_aval_sharding 
    raise ValueError(f"One of {what_aval}{name_str} was given the sharding " 
jax._src.traceback_util.UnfilteredStackTrace: ValueError: One of pjit outputs with pytree key path .params['decoder']['decoder']['self_attention']['key_layer_norm']['scale'].value was given the sharding of NamedSharding(mesh={'data': 1, 'fsdp': 192, 'tensor': 1}, spec=PartitionSpec('fsdp', None)), which implies that the global size of its dimension 0 should be divisible by 192, but it is equal to 256 (full shape: (256, 32)) 

The stack trace below excludes JAX-internal frames. 
The preceding is the original exception that occurred, unmodified. 
  
-------------------- 
  
The above exception was the direct cause of the following exception: 
  
Traceback (most recent call last): 
  File "maxtext/MaxText/train.py", line 334, in <module> 
    app.run(main) 
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run 
    _run_main(main, args) 
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main 
    sys.exit(main(argv)) 
  File "maxtext/MaxText/train.py", line 330, in main 
    train_loop(pyconfig.config) 
  File "maxtext/MaxText/train.py", line 277, in train_loop 
    state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, init_rng, mesh, checkpoint_manager) 
  File "/home/wx/maxtext/MaxText/max_utils.py", line 159, in setup_initial_state 
    state = pjit( 
ValueError: One of pjit outputs with pytree key path .params['decoder']['decoder']['self_attention']['key_layer_norm']['scale'].value was given the sharding of NamedSharding(mesh={'data': 1, 'fsdp': 192, 'tensor': 1}, spec=PartitionSpec('fsdp', None)), which implies that the global size of its dimension 0 should be divisible by 192, but it is equal to 256 (full shape: (256, 32))

It looks like this has to do with the sharding spec being incompatible with the tensor shape. Below are the commands I used to set up (on the main branch with jax 0.4.10) and run the experiment; any ideas on what went wrong here?

$ gcloud compute tpus tpu-vm ssh tpuv4 --zone=us-central2-b --worker=all --command="git clone https://github.com/google/maxtext.git" 
$ gcloud compute tpus tpu-vm ssh tpuv4 --zone=us-central2-b --worker=all --command="cd maxtext; sudo bash setup.sh" 
$ gcloud compute tpus tpu-vm ssh tpuv4 --zone=us-central2-b --worker=all --command="export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE'" 
$ gcloud compute tpus tpu-vm ssh tpuv4 --zone=us-central2-b --worker=all --command="python3 maxtext/MaxText/train.py maxtext/MaxText/configs/base.yml run_name=max_52B base_output_directory=gs://wx/max/ dataset_path=gs://maxtext_dt/ enable_profiler=true enable_checkpointing=false steps=10 ici_fsdp_parallelism=192 ici_tensor_parallelism=1 scale=4 base_num_decoder_layers=8 per_device_batch_size=10 remat_policy=full base_emb_dim=3072 base_mlp_dim=12288 learning_rate=1e-8" 
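
For context, here is a minimal standalone sketch (not MaxText code) of the constraint pjit is enforcing: a dimension sharded over a mesh axis must be evenly divisible by that axis's size. In the run above the fsdp axis has size 192, while the key_layer_norm scale's first dimension is 256, and 256 is not divisible by 192. The mesh construction and shapes below are only illustrative; reproducing the error assumes 192 devices on the fsdp axis.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a (data=1, fsdp=N, tensor=1) mesh over whatever devices are available.
devices = np.asarray(jax.devices()).reshape(1, -1, 1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# Same shape as the key_layer_norm scale in the error: (256, 32).
scale = np.ones((256, 32), dtype=np.float32)
sharding = NamedSharding(mesh, PartitionSpec("fsdp", None))

# With 192 devices on the fsdp axis this fails the divisibility check and raises
# a ValueError like the one above; with 1, 2, 4, ..., 256 devices it succeeds.
sharded = jax.device_put(scale, sharding)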
@rwitten
Collaborator

rwitten commented May 24, 2023

Thanks for reporting this, and sorry for the poor experience. Unfortunately, we don't have access to a v4-384 right now, so we don't have integration tests running and I can't get you timing data.

The issue was caused by adding KV layernorm to improve numerical stability.

I have put together a fix here; can you let me know if it unblocks you?
https://github.com/google/maxtext/pull/42/files
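
The general idea is to stop sharding the small KV layernorm scale along the fsdp axis. The sketch below is only illustrative, not the actual diff in the PR; the helper name and the annotation dict are made up for the example. Mapping the scale to a replicated PartitionSpec removes the divisibility requirement, at the cost of keeping a full copy of the (small) scale on each device.

from jax.sharding import PartitionSpec

def replicate_layer_norm_scales(path, spec):
  # Hypothetical helper: replicate *_layer_norm scale params instead of
  # sharding them along fsdp; every other parameter keeps its original spec.
  if "layer_norm" in path and path.endswith("scale"):
    return PartitionSpec()  # fully replicated, no divisibility constraint
  return spec

# Illustrative annotations, not the actual MaxText parameter pytree:
annotations = {
    "decoder/self_attention/key_layer_norm/scale": PartitionSpec("fsdp", None),
    "decoder/mlp/wi/kernel": PartitionSpec("fsdp", "tensor"),
}
fixed = {p: replicate_layer_norm_scales(p, s) for p, s in annotations.items()}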

@rwitten
Collaborator

rwitten commented May 30, 2023

Hi - I will close this out in a week (6/5) if I don't hear back.

@deeplearningapps
Author

Hi, I was able to run with the provided fix. Thanks!

A9isha pushed a commit that referenced this issue Apr 11, 2024
Change-Id: I2f50ac40c89f2f16a0601e75f608b7cb4428643a