Sweep across KV cache layouts #662

Open · yeandy wants to merge 77 commits into main

Conversation

@yeandy (Collaborator) commented May 21, 2024

Sweep across different sharding configurations for KV cache.

The user needs to set the config's inference_metadata_file, which is the path to a JSON file.

This JSON file should contain the following keys (a sketch of such a file follows the list):

  • key_value_axis_order_product_id_list: comma-separated string of key_value_axis_order_product_id values
  • prefill_key_axis_order_list: comma-delimited string of prefill_key_axis_order values
  • prefill_value_axis_order_list: comma-delimited string of prefill_value_axis_order values
  • ar_key_axis_order_list: comma-delimited string of ar_key_axis_order values
  • ar_value_axis_order_list: comma-delimited string of ar_value_axis_order values
  • accelerator: name of the accelerator
  • flatten_microbenchmark_results: whether to flatten results; should be true
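
A minimal sketch of producing such a file, assuming only the key names listed above; the values and the accelerator name are placeholders, not validated formats:

import json

# Placeholder values only -- the exact string formats for the axis-order lists
# are defined by the PR, not by this sketch.
inference_metadata = {
    "key_value_axis_order_product_id_list": "...",  # comma-separated product ids
    "prefill_key_axis_order_list": "...",           # comma-delimited prefill key axis orders
    "prefill_value_axis_order_list": "...",         # comma-delimited prefill value axis orders
    "ar_key_axis_order_list": "...",                # comma-delimited AR key axis orders
    "ar_value_axis_order_list": "...",              # comma-delimited AR value axis orders
    "accelerator": "v5e-8",                         # example accelerator name
    "flatten_microbenchmark_results": True,         # should be true
}

with open("inference_metadata.json", "w") as f:
    json.dump(inference_metadata, f, indent=2)

The config's inference_metadata_file would then point at the resulting inference_metadata.json.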

MaxText/configs/base.yml: two resolved review threads (outdated)
@morgandu (Collaborator) left a comment

Thank you Andy for the great work! Overall looks good, and I am happy to see the first pass results.

Some comments/suggestions:

I think my overall goal is to get rid of MaxText/inference_microbenchmark_sweep.py and make MaxText/inference_microbenchmark.py self-contained.

On the ml_auto_solutions side, any sweeping now or later can either use existing flags (base.yml) or we may need to introduce new flags as part of the experiment. Do a manual test run first, then scale up to more experiments. It'd be great if there were no extra, or only minimal, code changes required between the manual test and ml_auto_solutions.

Comment on lines 77 to 88
# Manually update the config
# Don't set key_value_axis_order_product_id; otherwise it will recompute
# ar_key_axis_order and ar_value_axis_order
quant = 'bf16' if not config.quantization else config.quantization
run_name = (
    f"{inference_metadata['accelerator']}-{config.model_name}-"
    f"{quant}-{key_value_axis_order_product_id}-{prefill_key_axis_order}-"
    f"{ar_key_axis_order}"
)
tensorboard_dir = os.path.join(config.base_output_directory, run_name, "tensorboard", "")
checkpoint_dir = os.path.join(config.base_output_directory, run_name, "checkpoint", "")
metrics_dir = os.path.join(config.base_output_directory, run_name, "metrics", "")
Collaborator

There are quant and quantize_kvcache, and different combinations of the two. As discussed, we will create a different test_config in xlml for each combination, and the base_run_name should already carry all the information needed to differentiate the runs (see the sketch below).
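
A minimal sketch (not the PR's code) of one way a base_run_name could encode both settings so runs stay distinguishable; the helper name and the kv tag values are made up:

def build_base_run_name(accelerator, model_name, quantization, quantize_kvcache):
  # Hypothetical helper: fold both the weight quantization and the KV-cache
  # quantization setting into the run name.
  quant = "bf16" if not quantization else quantization
  kv_tag = "kvquant" if quantize_kvcache else "kvbf16"  # made-up tags
  return f"{accelerator}-{model_name}-{quant}-{kv_tag}"

# e.g. build_base_run_name("v5e-8", "llama2-7b", "", True) -> "v5e-8-llama2-7b-bf16-kvquant"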

Comment on lines 93 to 95
pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir # pylint: disable=protected-access
pyconfig._config.keys['checkpoint_dir'] = checkpoint_dir # pylint: disable=protected-access
pyconfig._config.keys['metrics_dir'] = metrics_dir # pylint: disable=protected-access
Collaborator

I don't think checkpoint_dir and metrics_dir are used at all?

@yeandy (Collaborator, Author) left a comment

Please take a look

MaxText/inference_microbenchmark.py: two resolved review threads
pyconfig._config.keys['ar_value_axis_order'] = ar_value_axis_order # pylint: disable=protected-access
pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir # pylint: disable=protected-access
pyconfig._config.keys['run_name'] = run_name # pylint: disable=protected-access
max_utils.write_config_raw_keys_for_gcs(pyconfig._config.keys) # pylint: disable=protected-access
@morgandu (Collaborator) commented Jun 6, 2024

@yeandy (Collaborator, Author) commented Jun 6, 2024

My understanding is that write_config_raw_keys_for_gcs is called once during initialize (https://github.com/google/maxtext/blob/main/MaxText/pyconfig.py#L225) and writes the default values of the prefill and AR axis orders to GCS. So each time we loop over a different prefill/AR axis order, we need to explicitly call max_utils.write_config_raw_keys_for_gcs again to make sure we write the updated values to GCS, roughly as in the sketch below.

Did you take this out in your code, and did it work for you?
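
A minimal sketch of that per-iteration pattern, assuming MaxText-style flat imports and a config already initialized from base.yml; sweep_configs and make_run_name are hypothetical placeholders for the PR's sweep machinery:

import max_utils  # MaxText utility module (flat import, as in MaxText scripts)
import pyconfig   # MaxText config module; assumed already initialized from base.yml


def run_sweep(sweep_configs, make_run_name):
  """Hypothetical driver: re-point the live config at each axis-order combination."""
  for sweep_config in sweep_configs:
    run_name = make_run_name(sweep_config)  # hypothetical helper building the run name
    # Override the already-initialized config in place, mirroring the quoted snippets.
    pyconfig._config.keys['prefill_key_axis_order'] = sweep_config['prefill_key_axis_order']  # pylint: disable=protected-access
    pyconfig._config.keys['prefill_value_axis_order'] = sweep_config['prefill_value_axis_order']  # pylint: disable=protected-access
    pyconfig._config.keys['ar_key_axis_order'] = sweep_config['ar_key_axis_order']  # pylint: disable=protected-access
    pyconfig._config.keys['ar_value_axis_order'] = sweep_config['ar_value_axis_order']  # pylint: disable=protected-access
    pyconfig._config.keys['run_name'] = run_name  # pylint: disable=protected-access
    # initialize() writes the raw keys to GCS only once, with the default axis
    # orders, so it has to be called again here to persist the updated values.
    max_utils.write_config_raw_keys_for_gcs(pyconfig._config.keys)  # pylint: disable=protected-access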

@yeandy (Collaborator, Author)

@morgandu should I do a test without max_utils.write_config_raw_keys_for_gcs?

yeandy marked this pull request as ready for review June 17, 2024 16:52
@yeandy (Collaborator, Author) commented Jun 17, 2024

@morgandu Can you take a final look? Anything else we need to add?
