Sweep across KV cache layouts #662
Conversation
Thank you Andy for the great work! Overall looks good, and I am happy to see the first pass results.
Some comments/suggestions:
My overall goal is to get rid of MaxText/inference_microbenchmark_sweep.py and let MaxText/inference_microbenchmark.py be self-contained.
On the ml_auto_solutions side, any sweeping, now or later, can either use the existing flags (base.yml) or introduce new flags as part of the experiment. Do a manual test run first, then scale up to more experiments. It'd be great if there is no (or minimal) extra code required between the manual test and ml_auto_solutions.
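As a rough sketch of that suggestion, the sweep could live inside inference_microbenchmark.py and be driven purely by config values; the flag name `kv_cache_axis_orders` and the `sweep` helper below are illustrative assumptions, not the actual MaxText API:

```python
# Hypothetical sketch: a self-contained sweep driven only by config values.
# "kv_cache_axis_orders" and "model_name" are assumed flag names, not MaxText's.
def sweep(config):
    results = {}
    for axis_order in config.get("kv_cache_axis_orders", ["0123"]):
        run_name = f"{config['model_name']}-{axis_order}"
        # Placeholder: the real script would launch the microbenchmark here.
        results[run_name] = axis_order
    return results

print(sweep({"model_name": "llama2-7b",
             "kv_cache_axis_orders": ["0123", "0213"]}))
```

With this shape, a manual test run and an ml_auto_solutions run differ only in the config they pass, not in the code path.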
LGTM on my side! Adding @patemotter for visibility since he may need to use this soon.
# Manually update the config
# Don't set key_value_axis_order_product_id; otherwise it will recompute
# ar_key_axis_order and ar_value_axis_order
quant = 'bf16' if not config.quantization else config.quantization
run_name = (
    f"{inference_metadata['accelerator']}-{config.model_name}-"
    f"{quant}-{key_value_axis_order_product_id}-{prefill_key_axis_order}-"
    f"{ar_key_axis_order}"
)
tensorboard_dir = os.path.join(config.base_output_directory, run_name, "tensorboard", "")
checkpoint_dir = os.path.join(config.base_output_directory, run_name, "checkpoint", "")
metrics_dir = os.path.join(config.base_output_directory, run_name, "metrics", "")
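One detail worth noting in the snippet: the trailing empty string passed to os.path.join forces a trailing slash on each directory path. A minimal standalone check (the bucket and run name below are made-up values; only the join pattern mirrors the snippet):

```python
import os

# Made-up values; only the os.path.join pattern mirrors the snippet above.
base_output_directory = "gs://my-bucket/output"
run_name = "v5e-8-llama2-7b-bf16-0201-0132-0132"

# The trailing "" makes os.path.join emit a trailing separator.
tensorboard_dir = os.path.join(base_output_directory, run_name, "tensorboard", "")
print(tensorboard_dir)
```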
There are quant and quantize_kvcache, and different combinations of the two. As discussed, we will create different test_configs in xlml, and the base_run_name should already carry all the information needed to differentiate the runs.
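To illustrate the point, a small sketch enumerating the quant × quantize_kvcache combinations into distinct base run names; the value lists and naming scheme below are assumptions, the real test_configs live in xlml:

```python
from itertools import product

# Assumed value lists and naming scheme, for illustration only.
quantizations = ["bf16", "int8"]
kv_cache_quant = [False, True]

base_run_names = [
    f"{quant}-kvq{int(quantize_kvcache)}"
    for quant, quantize_kvcache in product(quantizations, kv_cache_quant)
]
print(base_run_names)
```

Every combination maps to a unique name, so runs never collide even when only the KV-cache quantization flag differs.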
pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir  # pylint: disable=protected-access
pyconfig._config.keys['checkpoint_dir'] = checkpoint_dir  # pylint: disable=protected-access
pyconfig._config.keys['metrics_dir'] = metrics_dir  # pylint: disable=protected-access
I don't think checkpoint_dir and metrics_dir are used at all?
Please take a look
@morgandu Can you take a final look? Anything else we need to add?
Final LGTM! Though the PR description needs to be updated, since we now have prefill_cache_axis_order and ar_cache_axis_order.
Updated description.
Sweep across different sharding configurations for the KV cache. This will be used in our automation infra: GoogleCloudPlatform/ml-auto-solutions#288
The user needs to set the config's inference_metadata_file, which is a path to a JSON file.
This JSON should contain the following keys:
be true
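For context, a hedged sketch of how such a metadata file might be written and read back; only the `accelerator` key is visible in the snippets above, so treat it as the sole confirmed key and anything else as an assumption:

```python
import json
import os
import tempfile

# Only "accelerator" appears in the code snippets above; any further keys
# required by the script are not shown here.
metadata = {"accelerator": "v5e-8"}

inference_metadata_file = os.path.join(tempfile.mkdtemp(), "inference_metadata.json")
with open(inference_metadata_file, "w") as f:
    json.dump(metadata, f)

# The script would load the file pointed to by config.inference_metadata_file.
with open(inference_metadata_file) as f:
    inference_metadata = json.load(f)
print(inference_metadata["accelerator"])
```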