Make Usage of Params Consistent v2 #568
Conversation
Force-pushed from f03988e to d2e4572
One important random change slipped into a test by accident.
Because this is now a high-risk change, I'd like the XLML tests run in advance. Can you run them all via XLML?
(And separately, can you post in the MaxText room what you're planning to do here and that it is breaking compat?)
MaxText/checkpointing.py (Outdated)
@@ -178,8 +178,11 @@ def map_to_pspec(data):
   p = epath.Path(load_parameters_from_path)
   ckptr = orbax.checkpoint.PyTreeCheckpointer()
   restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(abstract_unboxed_pre_state.params)
+  # Orbax quirk -> we save the entire TrainState, which has a field `params` which holds the PyTree
Hmmm this seems weird, should we fix that while we're fixing things?
Ack, I'll double check with Susie, who worked on this last I think. If I had to answer now: I think just saving the state directly is useful because we also get the optimizer state captured in the checkpoint. And we can cleanly capture any other variable collections (like overwrite_with_gradients) without any sort of special logic.
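For context, a minimal sketch of what saving the whole TrainState looks like (the save call is an assumption based on the Orbax API used in the diff above, not the exact MaxText code; `checkpoint_dir` and `state` are hypothetical):

```python
import orbax.checkpoint

# Save the entire TrainState (params, optimizer state, step, plus any
# extra variable collections) as one PyTree. Nothing collection-specific
# is needed: whatever is in the state gets checkpointed.
ckptr = orbax.checkpoint.PyTreeCheckpointer()
ckptr.save(checkpoint_dir, state)  # `checkpoint_dir`/`state` assumed to exist
```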
Update: We can actually just tweak this a little bit to load things using the abstract_state directly, and it works. Cleaner code, but I wonder if this is less efficient/wasteful because we are restoring things we don't need (i.e. all the stuff that isn't params).
And another update: this actually seems to be an optimization. We don't want to restore anything other than the params field here, as that would be wasteful both computationally and for memory usage (we immediately toss the rest of it). So we have this little odd bit here to basically say: only restore the params field. I added a comment clarifying what's really happening.
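Roughly, the trick is to present Orbax with an item that contains only the `params` subtree of the saved TrainState. A sketch under that assumption (kwargs follow the PyTreeCheckpointer API of this era; the exact call in MaxText may differ):

```python
import orbax.checkpoint
from etils import epath

p = epath.Path(load_parameters_from_path)
ckptr = orbax.checkpoint.PyTreeCheckpointer()

# Restore args built from the abstract (shape/sharding-only) params tree.
restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(
    abstract_unboxed_pre_state.params
)

# The checkpoint holds the whole TrainState, so we ask for just its
# `params` field; optimizer state etc. is never materialized.
restored = ckptr.restore(
    p,
    item={"params": abstract_unboxed_pre_state.params},
    restore_args={"params": restore_args},
    transforms={},
)
params = restored["params"]
```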
end_to_end/llama_finetuning_test.sh (Outdated)
@@ -13,6 +13,7 @@ DATASET_PATH=gs://maxtext-dataset

 export LOSS_THRESHOLD=2.5

+ echo python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${base_ckpt_path} model_name='llama2-7b' dataset_path=${DATASET_PATH} async_checkpointing=false model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 per_device_batch_size=.25 metrics_file='metrics.txt'
^^^
Whoops, good catch. Fixed!
Force-pushed from ba8863f to 504c1cd
Force-pushed from 2b5c6f7 to 6498e14
Made some fixes and added some clarifying comments! I've run the XLML tests and they passed, but I'm running them again with the latest gemma test changes. I won't merge until that's finished. I'll also post in the MaxText room before I merge.
#498 showed that we have inconsistent use of params vs {'params': params} within MaxText, which causes problems when we have changes that require multiple collections of variables (like https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py).
These changes streamline this as much as possible to use params directly.
Changes from Last Time:
The last version of this PR caused issues because it didn't update the checkpoint conversion scripts. Those have now been updated to reflect the updated checkpoint structure (which is basically just a 'params' key wrapping the outside of the dict).
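For illustration, a minimal sketch of the structural change the conversion scripts now emit (tree contents and the helper name are hypothetical; only the outer 'params' wrapping is the point):

```python
# Hypothetical example of the checkpoint-structure change: converted
# checkpoints now carry an outer 'params' key around the parameter tree.
old_tree = {"decoder": {"kernel": ...}}              # params tree at top level
new_tree = {"params": {"decoder": {"kernel": ...}}}  # wrapped in 'params'

def wrap_params(tree):
    # What the updated conversion scripts effectively do at the top level
    # (sketch only; the real scripts do much more than this).
    return {"params": tree}
```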