Sweep across KV cache layouts #662
base: main
Conversation
Thank you, Andy, for the great work! Overall this looks good, and I am happy to see the first-pass results.
Some comments/suggestions:
My overall goal is to get rid of MaxText/inference_microbenchmark_sweep.py and make MaxText/inference_microbenchmark.py self-contained.
On the ml_auto_solutions side, any sweep, now or later, can either use existing flags (from base.yml) or introduce new flags as part of the experiment. Do a manual test run first, then scale up to more experiments. It would be great if little or no extra code were required between the manual test and ml_auto_solutions.
# Manually update the config.
# Don't set key_value_axis_order_product_id; otherwise it will recompute
# ar_key_axis_order and ar_value_axis_order.
quant = 'bf16' if not config.quantization else config.quantization
run_name = (
    f"{inference_metadata['accelerator']}-{config.model_name}-"
    f"{quant}-{key_value_axis_order_product_id}-{prefill_key_axis_order}-"
    f"{ar_key_axis_order}"
)
tensorboard_dir = os.path.join(config.base_output_directory, run_name, "tensorboard", "")
checkpoint_dir = os.path.join(config.base_output_directory, run_name, "checkpoint", "")
metrics_dir = os.path.join(config.base_output_directory, run_name, "metrics", "")
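The directory layout built by the snippet above can be sketched on its own; a minimal sketch, where `base_output_directory` and `run_name` are placeholder values standing in for the config fields and the formatted sweep name:

```python
import os

# Illustrative stand-ins for config.base_output_directory and the run_name
# assembled from accelerator, model, quant, and axis-order values.
base_output_directory = "gs://my-bucket/sweeps"
run_name = "v5e-8-llama2-7b-bf16-0213-1203-1203"

# The trailing "" makes os.path.join emit a directory-style path ending in "/".
tensorboard_dir = os.path.join(base_output_directory, run_name, "tensorboard", "")
checkpoint_dir = os.path.join(base_output_directory, run_name, "checkpoint", "")
metrics_dir = os.path.join(base_output_directory, run_name, "metrics", "")

print(tensorboard_dir)
# → gs://my-bucket/sweeps/v5e-8-llama2-7b-bf16-0213-1203-1203/tensorboard/
```

Because `gs://` URIs use `/` separators, plain `os.path.join` produces the intended bucket paths here.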
There are quant and quantize_kvcache, and different combinations of the two. As discussed, we will create different test_configs in xlml, and the base_run_name should already carry all the information needed to differentiate the runs.
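As a sketch of the point above, a base_run_name can encode both settings so every (quant, quantize_kvcache) combination maps to a distinct run. The helper name and the tokens (`kvint8`, `kvbf16`) are illustrative assumptions, not the actual xlml config:

```python
# Hypothetical helper: builds a run name that distinguishes weight
# quantization ("quant") from KV-cache quantization ("quantize_kvcache").
def make_base_run_name(accelerator, model_name, quantization, quantize_kvcache):
  quant = quantization if quantization else "bf16"  # empty string means bf16
  kv = "kvint8" if quantize_kvcache else "kvbf16"
  return f"{accelerator}-{model_name}-{quant}-{kv}"

print(make_base_run_name("v5e-8", "llama2-7b", "", True))
# → v5e-8-llama2-7b-bf16-kvint8
```

With this shape, all four combinations of quantization and quantize_kvcache yield distinct names, so runs never collide in the output directory.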
pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir  # pylint: disable=protected-access
pyconfig._config.keys['checkpoint_dir'] = checkpoint_dir  # pylint: disable=protected-access
pyconfig._config.keys['metrics_dir'] = metrics_dir  # pylint: disable=protected-access
I don't think checkpoint_dir and metrics_dir are used at all?
Please take a look
pyconfig._config.keys['ar_value_axis_order'] = ar_value_axis_order  # pylint: disable=protected-access
pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir  # pylint: disable=protected-access
pyconfig._config.keys['run_name'] = run_name  # pylint: disable=protected-access
max_utils.write_config_raw_keys_for_gcs(pyconfig._config.keys)  # pylint: disable=protected-access
You can skip this, since there is already --save_config_to_gcs=True.
My understanding is that write_config_raw_keys_for_gcs is called once during initialize (https://github.com/google/maxtext/blob/main/MaxText/pyconfig.py#L225) and writes the default values of the prefill and ar axis orders to GCS. So each time we loop over a different prefill/ar axis order, we need to explicitly call max_utils.write_config_raw_keys_for_gcs again to make sure the updated values are written to GCS.
Did you take this out in your code, and did it work for you?
@morgandu should I do a test without max_utils.write_config_raw_keys_for_gcs?
@morgandu Can you take a final look? Anything else we need to add?
Sweep across different sharding configurations for the KV cache.
The user needs to set the config's inference_metadata_file, which is a path to a
JSON file. This JSON should contain the following keys:
be true
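The full key list is truncated in this excerpt, so only `accelerator` (which appears in the run_name code above) is known for certain. A minimal sketch of loading such a metadata file, with an illustrative value:

```python
import json
import os
import tempfile

# Write a sample inference_metadata_file; "accelerator" is the one key
# visible elsewhere in this diff, and "v5e-8" is an illustrative value.
metadata = {"accelerator": "v5e-8"}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
  json.dump(metadata, f)
  path = f.name

# Load it the way the benchmark would consume inference_metadata_file.
with open(path) as f:
  inference_metadata = json.load(f)

print(inference_metadata["accelerator"])
# → v5e-8
os.remove(path)
```

Any additional keys the sweep expects would be read from the same dictionary.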