
Make Usage of Params Consistent v2 #568

Merged
1 commit merged into main from anfals/params_streamline_redo on Apr 3, 2024

Conversation

anfals
Collaborator

@anfals anfals commented Mar 28, 2024

#498 showed that we have inconsistent use of params vs {'params': params} within MaxText, which causes problems when a change requires multiple collections of variables (like https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py).

These changes streamline this as much as possible to use params directly.
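
For reference, here is a minimal Flax sketch (a hypothetical toy module, not MaxText code) of the two call styles that were being mixed: passing the bare params PyTree and re-wrapping it at each call site, versus passing the whole variables dict, which is what lets extra collections (like the fp8 stats in fp8_ops.py) travel along without special handling.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyModel(nn.Module):  # hypothetical stand-in for the real model
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)

model = TinyModel()
x = jnp.ones((2, 3))
variables = model.init(jax.random.PRNGKey(0), x)  # {'params': {...}}
params = variables['params']                      # bare params PyTree

# Style 1: keep only the bare params and re-wrap at every call site.
y1 = model.apply({'params': params}, x)

# Style 2: keep the whole variables dict; any additional collections
# (e.g. fp8 scale/amax stats) are carried along with no special logic.
y2 = model.apply(variables, x)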

Changes from Last Time:
The last version of this PR caused issues because it didn't update the checkpoint conversion scripts. Those have now been updated to reflect the new checkpoint structure (which is essentially just a 'params' key wrapping the outside of the dict).

@anfals anfals requested a review from rwitten as a code owner March 28, 2024 22:35
@anfals anfals force-pushed the anfals/params_streamline_redo branch from f03988e to d2e4572 Compare March 29, 2024 21:13
@anfals anfals changed the title from [NOT READY FOR MERGE] Make Usage of Params Consistent to Make Usage of Params Consistent v2 Mar 29, 2024
@anfals anfals requested a review from gobbleturk March 29, 2024 21:17
Collaborator

@rwitten rwitten left a comment


One important random change slipped into a test by accident.

Because this is now a high-risk change, I'd like the XLML tests run in advance. Can you run them all via XLML?

(And separately, can you post in the MaxText room what you're planning to do here and that it is breaking compat?)

@@ -178,8 +178,11 @@ def map_to_pspec(data):
p = epath.Path(load_parameters_from_path)
ckptr = orbax.checkpoint.PyTreeCheckpointer()
restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(abstract_unboxed_pre_state.params)
# Orbax quirk -> we save the entire TrainState, which has a field `params` which holds the PyTree
Collaborator

Hmmm this seems weird, should we fix that while we're fixing things?

Collaborator Author

Ack, I'll double-check with Susie, who I think worked on this last. If I had to answer now, I think saving the state directly is useful because we also get the optimizer state captured in the checkpoint, and we can cleanly capture any other variable collections (like overwrite_with_gradients) without any special logic.

Collaborator Author
@anfals anfals Apr 2, 2024

Update: We can actually just tweak this a little bit to load things using the abstract_state directly, and it works. Cleaner code, but I wonder if it's less efficient/wasteful because we're restoring things we don't need (i.e. all the stuff that isn't params).

Collaborator Author

And another update: this actually turns out to be an optimization. We don't want to restore anything other than the params field here, as that would be wasteful both computationally and in memory (we immediately toss the rest). So we keep this little odd bit to say: only restore the params field. I added a comment clarifying what's really happening.
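
A hedged sketch of that idea, pieced together from the diff context above (`load_parameters_from_path` and `abstract_unboxed_pre_state` come from the surrounding MaxText code; the exact Orbax keyword arguments may differ between versions, so treat this as an illustration rather than the PR's literal diff):

import orbax.checkpoint
from etils import epath

p = epath.Path(load_parameters_from_path)
ckptr = orbax.checkpoint.PyTreeCheckpointer()

# Build restore args only for the `params` subtree of the saved TrainState.
restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(
    abstract_unboxed_pre_state.params)

# The checkpoint holds the entire TrainState, but we only ask Orbax to
# materialize its `params` field; the optimizer state and step counter
# would be thrown away immediately, so restoring them would waste
# compute and memory.
restored = ckptr.restore(
    p,
    item={'params': abstract_unboxed_pre_state.params},
    transforms={},
    restore_args={'params': restore_args},
)
params = restored['params']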

@@ -13,6 +13,7 @@ DATASET_PATH=gs://maxtext-dataset

export LOSS_THRESHOLD=2.5

echo python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${base_ckpt_path} model_name='llama2-7b' dataset_path=${DATASET_PATH} async_checkpointing=false model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 per_device_batch_size=.25 metrics_file='metrics.txt'
Collaborator

^^^

Collaborator Author

whoops good catch. fixed!

@rwitten rwitten removed their assignment Mar 30, 2024
@anfals anfals force-pushed the anfals/params_streamline_redo branch from ba8863f to 504c1cd Compare April 2, 2024 17:19
@anfals anfals force-pushed the anfals/params_streamline_redo branch from 2b5c6f7 to 6498e14 Compare April 2, 2024 18:14
@anfals
Copy link
Collaborator Author

anfals commented Apr 2, 2024

One important random change slipped into a test by accident.

Because this is now a high-risk change, I'd like the XLML tests run in advance. Can you run them all via XLML?

(And separately, can you post in the MaxText room what you're planning to do here and that it is breaking compat?)

Made some fixes and added some clarifying comments!

I've run the XLML tests and they passed, but I'm running them again with the latest gemma test changes. I won't merge until that's finished. I'll also post in the MaxText room before I merge.

@anfals anfals assigned anfals and rwitten and unassigned anfals Apr 2, 2024
@anfals anfals requested a review from rwitten April 2, 2024 20:07
@rwitten rwitten removed their assignment Apr 3, 2024
@copybara-service copybara-service bot merged commit de49d83 into main Apr 3, 2024
9 checks passed
@copybara-service copybara-service bot deleted the anfals/params_streamline_redo branch April 3, 2024 18:45
@A9isha A9isha mentioned this pull request Apr 11, 2024