Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sweep across KV cache layouts #662

Open
wants to merge 77 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 61 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
91f4398
layout
morgandu May 16, 2024
18ba546
Add flag to flatten
yeandy May 20, 2024
e78faee
Update config
yeandy May 20, 2024
6c04467
tmp
yeandy May 20, 2024
b839d63
Fix delimiter
yeandy May 20, 2024
baaeda7
Fix delimiter
yeandy May 20, 2024
f0588ee
Fix delimiter
yeandy May 20, 2024
d85990b
Fix
yeandy May 21, 2024
cbb90bf
Fix
yeandy May 21, 2024
929ad8d
Fix
yeandy May 21, 2024
f3cbd3e
Fix write
yeandy May 21, 2024
0a6de55
Fix
yeandy May 21, 2024
f907515
Fix
yeandy May 21, 2024
8333659
Fix
yeandy May 21, 2024
8c0707d
Fix
yeandy May 21, 2024
674ba29
Fix
yeandy May 21, 2024
95f846f
Fix
yeandy May 21, 2024
858f4e3
Catch OOM
yeandy May 21, 2024
1c20d61
Fix
yeandy May 22, 2024
a9b381a
Update tensorboard dir
yeandy May 22, 2024
ec1d76e
add prefill
yeandy May 22, 2024
63edcc2
Fix
yeandy May 22, 2024
268ee64
Fix
yeandy May 22, 2024
c1490f6
Fix
yeandy May 22, 2024
0fe1d51
Fix
yeandy May 22, 2024
554d014
Fix
yeandy May 22, 2024
6458a89
Fix
yeandy May 22, 2024
76bda5d
Fix
yeandy May 22, 2024
42bd412
Fix
yeandy May 22, 2024
df5c66e
Fix
yeandy May 23, 2024
c44da0c
Fix
yeandy May 23, 2024
d84f4c9
Fix
yeandy May 23, 2024
185563d
Add layout control
morgandu May 22, 2024
4291c4c
test
yeandy May 23, 2024
c638c63
test
yeandy May 23, 2024
128a691
test
yeandy May 23, 2024
8387cd7
test
yeandy May 23, 2024
b422007
test
yeandy May 24, 2024
f69f405
test
yeandy May 24, 2024
c37ad18
test
yeandy May 24, 2024
4b3be09
test
yeandy May 24, 2024
65c2289
test
yeandy May 24, 2024
d5637e4
merge
yeandy May 24, 2024
9c0d6a5
Fix
yeandy May 24, 2024
58405b1
Serialize config json
yeandy May 24, 2024
187cb3d
Add kv cache layout control and tests
morgandu May 30, 2024
958938c
Fix string concat
yeandy May 30, 2024
e37e495
Merge mor--kv-cache-layout
yeandy May 30, 2024
a7fb24b
Remove duplicates
yeandy May 30, 2024
cb2db18
Rename enable_profiler
yeandy May 30, 2024
7311225
json default val
yeandy May 30, 2024
000e935
Enable kv cache layout control
morgandu May 30, 2024
ceda588
Fix failed tests
yeandy May 30, 2024
69bfc92
Add instructions
yeandy May 30, 2024
ca107c5
merge
yeandy May 30, 2024
1d417e4
Address comments
yeandy May 31, 2024
0c9a6f3
Fix typing
yeandy May 31, 2024
296119b
debug
yeandy May 31, 2024
58fec1e
Debug
yeandy May 31, 2024
74a1e70
Debug
yeandy May 31, 2024
70dd34c
Fix
yeandy May 31, 2024
bd29b6e
Lint
yeandy May 31, 2024
65e12a4
Merge
yeandy Jun 3, 2024
746bbda
Remove redundant
yeandy Jun 3, 2024
3d0214e
Fix
yeandy Jun 4, 2024
0151bb0
Address comments
yeandy Jun 6, 2024
951200e
Fix
yeandy Jun 6, 2024
22cae3b
Fix typo
yeandy Jun 10, 2024
081891f
Remove unused attention
yeandy Jun 10, 2024
0be7296
Fix typo
yeandy Jun 10, 2024
b6fd256
Fix
yeandy Jun 10, 2024
220cfd2
Merge branch 'main' into mor--kv-cache-layout-reformat-output
yeandy Jun 11, 2024
4b9c8a3
Merge branch 'main' into mor--kv-cache-layout-reformat-output
yeandy Jun 17, 2024
7fd12c6
Merge branch 'main' into mor--kv-cache-layout-reformat-output
yeandy Jun 17, 2024
7c9ccae
Merge branch 'main' into mor--kv-cache-layout-reformat-output
yeandy Jun 18, 2024
5d798a6
change the sweeping to prefill and ar cache only (#712)
morgandu Jun 18, 2024
4f0a12c
Remove unused
yeandy Jun 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions MaxText/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@
ScanIn = partitioning.ScanIn

AxisNames = tuple[str, ...]
AxisIdxes = tuple[int, ...]

BATCH = "activation_batch"
LENGTH = "activation_length"
HEAD = "activation_heads"
D_KV = "activation_kv"
CACHE_BATCH = "cache_batch"
CACHE_SEQUENCE = "cache_sequence"
CACHE_HEADS = "cache_heads"
CACHE_KV = "cache_kv"

MODEL_MODE_AUTOREGRESSIVE = "autoregressive"
MODEL_MODE_PREFILL = "prefill"
Expand Down
9 changes: 9 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,15 @@ inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
inference_microbenchmark_log_file_path: ""
inference_metadata_file: "" # path to a json file

# KV Cache layout control
# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
# Default layout: 1,2,0,3 ; CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV
prefill_key_axis_order: "1,2,0,3"
prefill_value_axis_order: "1,2,0,3"
ar_key_axis_order: "1,2,0,3"
ar_value_axis_order: "1,2,0,3"

# Checkpoint Structured logging
enable_checkpoint_cloud_logger: False
Expand Down
33 changes: 30 additions & 3 deletions MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import json
import sys

from collections.abc import MutableMapping
from typing import Any, Dict, Optional

from jetstream.engine import token_utils

import max_utils
Expand Down Expand Up @@ -170,10 +173,25 @@ def collate_results(config, results, model_size, cache_size, num_model_params, i
return results


def write_results(results, filename):
def flatten_dict(dictionary, prefix='', sep='_'):
  """Recursively collapse a nested mapping into a single-level dict.

  Nested keys are joined with ``sep``, e.g. ``{'a': {'b': 1}}`` becomes
  ``{'a_b': 1}``. Non-mapping values are copied through unchanged.

  Args:
    dictionary: the (possibly nested) mapping to flatten.
    prefix: key prefix carried down from enclosing levels; empty at top level.
    sep: separator placed between joined key components.

  Returns:
    A flat dict mapping joined keys to leaf values.
  """
  flat = {}
  for key, value in dictionary.items():
    compound_key = f"{prefix}{sep}{key}" if prefix else key
    if isinstance(value, MutableMapping):
      flat.update(flatten_dict(value, compound_key, sep=sep))
    else:
      flat[compound_key] = value
  return flat


def write_results(results, filename, flatten_microbenchmark_results=False):
  """Write the microbenchmark results to a json file.

  Args:
    results: dict of (possibly nested) microbenchmark results.
    filename: path of the json file to write; writing is skipped when empty.
    flatten_microbenchmark_results: when truthy, add a 'flattened_results'
      key holding a single-level copy of the results (see flatten_dict).
      Defaults to False so pre-existing two-argument callers keep working.

  Returns:
    The results dict, augmented with 'flattened_results' when requested.
  """
  if flatten_microbenchmark_results:
    results['flattened_results'] = flatten_dict(results)
  if filename != "":
    with open(filename, "w", encoding="utf-8") as f:
      json.dump(results, f, indent=2)
  return results


def print_results_for_analyze(results):
Expand Down Expand Up @@ -218,7 +236,7 @@ def summarize_prefill_result(engine, params, tokens, true_length):
}


def main(config):
def main(config, inference_metadata: Optional[Dict[str, Any]] = None):
engine = maxengine.MaxEngine(config)
params = engine.load_params()
prefill_lengths = [int(l) for l in config.inference_microbenchmark_prefill_lengths.split(",")]
Expand Down Expand Up @@ -277,8 +295,17 @@ def main(config):
config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)

results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
write_results(results, filename=config.inference_microbenchmark_log_file_path)
print_results_for_analyze(results)
if inference_metadata:
flatten_microbenchmark_results = pyconfig.string_to_bool(inference_metadata.get('flatten_microbenchmark_results', 'false'))
else:
flatten_microbenchmark_results = 'false'
results = write_results(
results,
filename=config.inference_microbenchmark_log_file_path,
flatten_microbenchmark_results=flatten_microbenchmark_results
)
return results


if __name__ == "__main__":
Expand Down
154 changes: 154 additions & 0 deletions MaxText/inference_microbenchmark_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Sweep across inference microbenchmarks."""

import os
import sys
import json
import jsonlines
import inference_microbenchmark
import max_utils
import pyconfig
from jax._src.lib import xla_extension


def main():
  """Sweep inference microbenchmarks across KV cache layout configurations.

  The config's `inference_metadata_file` must point to a json file
  containing the following keys:
    - key_value_axis_order_product_id_list: colon-delimited string of
      key_value_axis_order_product_id values
    - prefill_key_axis_order_list: colon-delimited string of
      prefill_key_axis_order values
    - prefill_value_axis_order_list: colon-delimited string of
      prefill_value_axis_order values
    - ar_key_axis_order_list: colon-delimited string of ar_key_axis_order values
    - ar_value_axis_order_list: colon-delimited string of ar_value_axis_order values
    - accelerator: name of the accelerator
    - flatten_microbenchmark_results: whether or not to flatten results;
      should be true so each run's metrics can be exported as flat key/value
      pairs below

  Writes one json line per swept configuration to
  inference_microbenchmark_sweep_results.jsonl in the working directory.
  """
  pyconfig.initialize(sys.argv)
  config = pyconfig.config

  with open(config.inference_metadata_file, encoding='utf-8') as json_file:
    inference_metadata = json.load(json_file)
    print(f"inference_metadata: {inference_metadata}")

  # The list fields are colon-delimited because each axis order is itself a
  # comma-separated string such as "1,2,0,3".
  key_value_axis_order_product_id_list = inference_metadata['key_value_axis_order_product_id_list'].split(':')
  prefill_key_axis_order_list = inference_metadata['prefill_key_axis_order_list'].split(':')
  prefill_value_axis_order_list = inference_metadata['prefill_value_axis_order_list'].split(':')
  ar_key_axis_order_list = inference_metadata['ar_key_axis_order_list'].split(':')
  ar_value_axis_order_list = inference_metadata['ar_value_axis_order_list'].split(':')

  results = []
  for (
      key_value_axis_order_product_id,
      prefill_key_axis_order,
      prefill_value_axis_order,
      ar_key_axis_order,
      ar_value_axis_order,
  ) in zip(
      key_value_axis_order_product_id_list,
      prefill_key_axis_order_list,
      prefill_value_axis_order_list,
      ar_key_axis_order_list,
      ar_value_axis_order_list,
  ):
    print(f"key_value_axis_order_product_id {key_value_axis_order_product_id}")
    print(f"prefill_key_axis_order {prefill_key_axis_order}")
    print(f"prefill_value_axis_order {prefill_value_axis_order}")
    print(f"ar_key_axis_order {ar_key_axis_order}")
    print(f"ar_value_axis_order {ar_value_axis_order}")

    # Manually update the config.
    # Don't set key_value_axis_order_product_id; otherwise it will recompute
    # ar_key_axis_order and ar_value_axis_order.
    quant = 'bf16' if not config.quantization else config.quantization
    run_name = (
        f"{inference_metadata['accelerator']}-{config.model_name}-"
        f"{quant}-{key_value_axis_order_product_id}-{prefill_key_axis_order}-"
        f"{ar_key_axis_order}"
    )
    tensorboard_dir = os.path.join(config.base_output_directory, run_name, "tensorboard", "")
    checkpoint_dir = os.path.join(config.base_output_directory, run_name, "checkpoint", "")
    metrics_dir = os.path.join(config.base_output_directory, run_name, "metrics", "")
    pyconfig._config.keys['prefill_key_axis_order'] = prefill_key_axis_order  # pylint: disable=protected-access
    pyconfig._config.keys['prefill_value_axis_order'] = prefill_value_axis_order  # pylint: disable=protected-access
    pyconfig._config.keys['ar_key_axis_order'] = ar_key_axis_order  # pylint: disable=protected-access
    pyconfig._config.keys['ar_value_axis_order'] = ar_value_axis_order  # pylint: disable=protected-access
    pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir  # pylint: disable=protected-access
    # NOTE(review): checkpoint_dir and metrics_dir may not be consumed by the
    # benchmark itself; they are updated here only to keep the persisted
    # config consistent with run_name.
    pyconfig._config.keys['checkpoint_dir'] = checkpoint_dir  # pylint: disable=protected-access
    pyconfig._config.keys['metrics_dir'] = metrics_dir  # pylint: disable=protected-access
    pyconfig._config.keys['run_name'] = run_name  # pylint: disable=protected-access
    # write_config_raw_keys_for_gcs runs once during pyconfig.initialize, so
    # re-run it here to persist the axis orders updated for this iteration.
    max_utils.write_config_raw_keys_for_gcs(pyconfig._config.keys)  # pylint: disable=protected-access

    # Prepare metadata (dimensions) json for XLML.
    dimensions_json = {
        "base_output_directory": config.base_output_directory,
        "model_name": config.model_name,
        "tokenizer": config.tokenizer_path,
        "weight_dtype": config.weight_dtype,
        "inference_microbenchmark_prefill_lengths": f"{config.inference_microbenchmark_prefill_lengths}",
        "inference_microbenchmark_stages": config.inference_microbenchmark_stages,
        "inference_microbenchmark_loop_iters": f"{config.inference_microbenchmark_loop_iters}",
        "max_prefill_predict_length": f"{config.max_prefill_predict_length}",
        "max_target_length": f"{config.max_target_length}",
        "per_device_batch_size": f"{config.per_device_batch_size}",
        "ici_fsdp_parallelism": f"{config.ici_fsdp_parallelism}",
        "ici_autoregressive_parallelism": f"{config.ici_autoregressive_parallelism}",
        "ici_tensor_parallelism": f"{config.ici_tensor_parallelism}",
        "profiler": f"{config.profiler}",
        "scan_layers": f"{config.scan_layers}",
        "quantization": config.quantization,
        "quantize_kvcache": f"{config.quantize_kvcache}",
        "attention": config.attention,
        "key_value_axis_order_product_id": f"{key_value_axis_order_product_id}",
        "prefill_key_axis_order": f"{prefill_key_axis_order}",
        "prefill_value_axis_order": f"{prefill_value_axis_order}",
        "ar_key_axis_order": f"{ar_key_axis_order}",
        "ar_value_axis_order": f"{ar_value_axis_order}",
        # Non-serializable config values (e.g. mesh objects) are replaced
        # with a placeholder string rather than failing the dump.
        "config_json_string": json.dumps(
            pyconfig._config.keys,  # pylint: disable=protected-access
            default=lambda x: f"<<non-serializable: {type(x).__qualname__}>>"
        )
    }
    dimensions_json = {
        **dimensions_json,
        **inference_metadata,
    }
    try:
      microbenchmark_results = inference_microbenchmark.main(config, inference_metadata=inference_metadata)
      metrics = microbenchmark_results['flattened_results']
      metrics = {k.lower(): v for k, v in metrics.items()}
      dimensions_json['oom'] = 'False'
    except xla_extension.XlaRuntimeError:
      # Treat an XLA runtime failure as out-of-memory for this layout and
      # keep sweeping the remaining configurations.
      metrics = {}
      dimensions_json['oom'] = 'True'

    final = {'metrics': metrics, 'dimensions': dimensions_json}
    print(f"Result: {final}")
    results.append(final)

  print(f"All results {results}")
  path = 'inference_microbenchmark_sweep_results.jsonl'
  with jsonlines.open(path, mode="w") as writer:
    writer.write_all(results)


if __name__ == "__main__":
main()