Sweep across KV cache layouts #662

Status: Open. Wants to merge 77 commits into base: main. The diff below shows the changes from a single commit within the PR.

Commits (77)
91f4398  layout (morgandu, May 16, 2024)
18ba546  Add flag to flatten (yeandy, May 20, 2024)
e78faee  Update config (yeandy, May 20, 2024)
6c04467  tmp (yeandy, May 20, 2024)
b839d63  Fix delimiter (yeandy, May 20, 2024)
baaeda7  Fix delimiter (yeandy, May 20, 2024)
f0588ee  Fix delimiter (yeandy, May 20, 2024)
d85990b  Fix (yeandy, May 21, 2024)
cbb90bf  Fix (yeandy, May 21, 2024)
929ad8d  Fix (yeandy, May 21, 2024)
f3cbd3e  Fix write (yeandy, May 21, 2024)
0a6de55  Fix (yeandy, May 21, 2024)
f907515  Fix (yeandy, May 21, 2024)
8333659  Fix (yeandy, May 21, 2024)
8c0707d  Fix (yeandy, May 21, 2024)
674ba29  Fix (yeandy, May 21, 2024)
95f846f  Fix (yeandy, May 21, 2024)
858f4e3  Catch OOM (yeandy, May 21, 2024)
1c20d61  Fix (yeandy, May 22, 2024)
a9b381a  Update tensorboard dir (yeandy, May 22, 2024)
ec1d76e  add prefill (yeandy, May 22, 2024)
63edcc2  Fix (yeandy, May 22, 2024)
268ee64  Fix (yeandy, May 22, 2024)
c1490f6  Fix (yeandy, May 22, 2024)
0fe1d51  Fix (yeandy, May 22, 2024)
554d014  Fix (yeandy, May 22, 2024)
6458a89  Fix (yeandy, May 22, 2024)
76bda5d  Fix (yeandy, May 22, 2024)
42bd412  Fix (yeandy, May 22, 2024)
df5c66e  Fix (yeandy, May 23, 2024)
c44da0c  Fix (yeandy, May 23, 2024)
d84f4c9  Fix (yeandy, May 23, 2024)
185563d  Add layout control (morgandu, May 22, 2024)
4291c4c  test (yeandy, May 23, 2024)
c638c63  test (yeandy, May 23, 2024)
128a691  test (yeandy, May 23, 2024)
8387cd7  test (yeandy, May 23, 2024)
b422007  test (yeandy, May 24, 2024)
f69f405  test (yeandy, May 24, 2024)
c37ad18  test (yeandy, May 24, 2024)
4b3be09  test (yeandy, May 24, 2024)
65c2289  test (yeandy, May 24, 2024)
d5637e4  merge (yeandy, May 24, 2024)
9c0d6a5  Fix (yeandy, May 24, 2024)
58405b1  Serialize config json (yeandy, May 24, 2024)
187cb3d  Add kv cache layout control and tests (morgandu, May 30, 2024)
958938c  Fix string concat (yeandy, May 30, 2024)
e37e495  Merge mor--kv-cache-layout (yeandy, May 30, 2024)
a7fb24b  Remove duplicates (yeandy, May 30, 2024)
cb2db18  Rename enable_profiler (yeandy, May 30, 2024)
7311225  json default val (yeandy, May 30, 2024)
000e935  Enable kv cache layout control (morgandu, May 30, 2024)
ceda588  Fix failed tests (yeandy, May 30, 2024)
69bfc92  Add instructions (yeandy, May 30, 2024)
ca107c5  merge (yeandy, May 30, 2024)
1d417e4  Address comments (yeandy, May 31, 2024)
0c9a6f3  Fix typing (yeandy, May 31, 2024)
296119b  debug (yeandy, May 31, 2024)
58fec1e  Debug (yeandy, May 31, 2024)
74a1e70  Debug (yeandy, May 31, 2024)
70dd34c  Fix (yeandy, May 31, 2024)
bd29b6e  Lint (yeandy, May 31, 2024)
65e12a4  Merge (yeandy, Jun 3, 2024)
746bbda  Remove redundant (yeandy, Jun 3, 2024)
3d0214e  Fix (yeandy, Jun 4, 2024)
0151bb0  Address comments (yeandy, Jun 6, 2024)
951200e  Fix (yeandy, Jun 6, 2024)
22cae3b  Fix typo (yeandy, Jun 10, 2024)
081891f  Remove unused attention (yeandy, Jun 10, 2024)
0be7296  Fix typo (yeandy, Jun 10, 2024)
b6fd256  Fix (yeandy, Jun 10, 2024)
220cfd2  Merge branch 'main' into mor--kv-cache-layout-reformat-output (yeandy, Jun 11, 2024)
4b9c8a3  Merge branch 'main' into mor--kv-cache-layout-reformat-output (yeandy, Jun 17, 2024)
7fd12c6  Merge branch 'main' into mor--kv-cache-layout-reformat-output (yeandy, Jun 17, 2024)
7c9ccae  Merge branch 'main' into mor--kv-cache-layout-reformat-output (yeandy, Jun 18, 2024)
5d798a6  change the sweeping to prefill and ar cache only (#712) (morgandu, Jun 18, 2024)
4f0a12c  Remove unused (yeandy, Jun 18, 2024)
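
Taken together, the commits implement a benchmark sweep over KV cache layouts: permutations of the cache axis order, chosen independently for the prefill and autoregressive (ar) caches. A minimal sketch of what such a sweep space looks like (hypothetical; the scripts in the diff below read concrete axis orders from an inference metadata file rather than enumerating permutations like this):

import itertools

# Hypothetical enumeration of KV cache layouts. Each layout is a permutation
# of the four cache axes; the prefill and ar caches can use different layouts.
CACHE_AXES = (0, 1, 2, 3)  # assumed logical axis ids before reordering

for product_id, (prefill_order, ar_order) in enumerate(
    itertools.product(itertools.permutations(CACHE_AXES), repeat=2)
):
  prefill_key_axis_order = ",".join(map(str, prefill_order))  # e.g. "1,2,0,3"
  ar_key_axis_order = ",".join(map(str, ar_order))
  print(product_id, prefill_key_axis_order, ar_key_axis_order)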
64 changes: 32 additions & 32 deletions MaxText/inference_microbenchmark.py
@@ -66,9 +66,9 @@ def prefill_benchmark(
       f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n"
   )
   result_dict = {
-      "prefill_time_in_ms": prefill_average_ms,
-      "prefill_total_tflops_per_device": prefill_tflops_per_device,
-      "prefill_tflops_per_sec_per_device": tflops_per_sec_per_device,
+      "time_in_ms": prefill_average_ms,
+      "total_tflops_per_device": prefill_tflops_per_device,
+      "tflops_per_sec_per_device": tflops_per_sec_per_device,
   }
   return result_dict

@@ -109,7 +109,7 @@ def prefill_insert_benchmark(
       f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n"
   )
   result_dict = {
-      "prefill_insert_time_in_ms": prefill_insert_average_ms
+      "insert_time_in_ms": prefill_insert_average_ms
   }
   return result_dict, decode_state

@@ -150,20 +150,20 @@ def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_
   )

   result_dict = {
-      "ar_step_in_ms": ar_average_ms,
-      "ar_step_in_ms_per_seq": ar_average_ms / global_batch_size,
-      "ar_global_batch_size": global_batch_size,
-      "ar_total_throughput_tokens_per_second": total_throughput,
-      "ar_device_bandwidth_GB_per_second": bw_per_device,
+      "step_in_ms": ar_average_ms,
+      "step_in_ms_per_seq": ar_average_ms / global_batch_size,
+      "global_batch_size": global_batch_size,
+      "total_throughput_tokens_per_second": total_throughput,
+      "device_bandwidth_GB_per_second": bw_per_device,
   }
   return result_dict, decode_state


 def collate_results(config, results, model_size, cache_size, num_model_params, incl_config=False):
   """Adds model/cache size info and optionally config info to results."""
   results["sizes"] = {
-      "Model_size_in_GB": model_size / 1e9,
-      "cache_size_in_GB": cache_size / 1e9,
+      "model_size_in_gb": model_size / 1e9,
+      "cache_size_in_gb": cache_size / 1e9,
       "model_params_in_billions": num_model_params / 1e9,
   }
   if incl_config:
@@ -198,20 +198,20 @@ def print_results_for_analyze(results):
   """Print results."""
   print("\nFor usage in analyze_sharegpt.py :")

-  if "Prefill" in results:
+  if "prefill" in results:
     prefill_bucket_size_to_ms = {}
-    for k, v in results["Prefill"].items():
-      prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
+    for k, v in results["prefill"].items():
+      prefill_bucket_size_to_ms[int(k)] = round(v["time_in_ms"], 3)
     print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")

-  if "Prefill_Insert" in results:
+  if "prefill-insert" in results:
     insert_bucket_size_to_ms = {}
-    for k, v in results["Prefill_Insert"].items():
-      insert_bucket_size_to_ms[int(k)] = round(v["prefill_insert_time_in_ms"], 3)
+    for k, v in results["prefill-insert"].items():
+      insert_bucket_size_to_ms[int(k)] = round(v["insert_time_in_ms"], 3)
     print(f"PREFILL_INSERT_BUCKET_SIZE_TO_MS = {insert_bucket_size_to_ms}")

-  if "AutoRegressive" in results:
-    print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")
+  if "autoregressive" in results:
+    print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['autoregressive']['step_in_ms_per_seq']}")


 def summarize_prefill_result(engine, params, tokens, true_length):
@@ -227,12 +227,12 @@ def summarize_prefill_result(engine, params, tokens, true_length):
   )
   max_utils.delete_pytree(prefill_result)
   return {
-      "num_prefill_logits_params": num_prefill_logits_params,
-      "total_prefill_logits_size": total_prefill_logits_size,
-      "avg_prefill_logits_param_size": avg_prefill_logits_param_size,
-      "num_prefill_cache_params": num_prefill_cache_params,
-      "total_prefill_cache_size": total_prefill_cache_size,
-      "avg_prefill_cache_param_size": avg_prefill_cache_param_size,
+      "num_logits_params": num_prefill_logits_params,
+      "total_logits_size": total_prefill_logits_size,
+      "avg_logits_param_size": avg_prefill_logits_param_size,
+      "num_cache_params": num_prefill_cache_params,
+      "total_cache_size": total_prefill_cache_size,
+      "avg_cache_param_size": avg_prefill_cache_param_size,
   }


@@ -254,22 +254,22 @@ def main(config, inference_metadata: Optional[Dict[str, Any]] = None):
   benchmark_results = {}
   if "prefill" in stages_to_benchmark:

-    benchmark_results["Prefill_Result"] = {}
-    benchmark_results["Prefill"] = {}
-    benchmark_results["Prefill_Insert"] = {}
+    benchmark_results["prefill-result-sizes"] = {}
+    benchmark_results["prefill"] = {}
+    benchmark_results["prefill-insert"] = {}
     prefill_tokens = {}
     prefill_true_lengths = {}

     for prefill_length in prefill_lengths:
       prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad(
           text, vocab, is_bos=True, prefill_lengths=[prefill_length]
       )
-      benchmark_results["Prefill_Result"]["prefill_length"] = summarize_prefill_result(
+      benchmark_results["prefill-result-sizes"][prefill_length] = summarize_prefill_result(
           engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
       )

     for prefill_length in prefill_lengths:
-      benchmark_results["Prefill"][prefill_length] = prefill_benchmark(
+      benchmark_results["prefill"][prefill_length] = prefill_benchmark(
           config,
           engine,
           params,
@@ -279,7 +279,7 @@ def main(config, inference_metadata: Optional[Dict[str, Any]] = None):
           benchmark_loop_iters
       )

-      benchmark_results["Prefill_Insert"][prefill_length], decode_state = prefill_insert_benchmark(
+      benchmark_results["prefill-insert"][prefill_length], decode_state = prefill_insert_benchmark(
          config,
          engine,
          decode_state,
@@ -291,7 +291,7 @@ def main(config, inference_metadata: Optional[Dict[str, Any]] = None):
       )

   if "generate" in stages_to_benchmark:
-    benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
+    benchmark_results["autoregressive"], decode_state = ar_benchmark(
         config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)

   results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
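
Net effect of the renames in this file: each stage's metrics are nested under a lowercase, hyphenated stage key, and the stage prefix is dropped from the individual metric names. A minimal sketch of the resulting structure (illustrative only; the 1024 bucket and all metric values are made up, not real benchmark output):

# Hypothetical shape of the collated results after this change (values made up).
results = {
    "prefill": {
        1024: {
            "time_in_ms": 12.345,
            "total_tflops_per_device": 4.2,
            "tflops_per_sec_per_device": 340.1,
        },
    },
    "prefill-insert": {
        1024: {"insert_time_in_ms": 13.9},
    },
    "autoregressive": {
        "step_in_ms": 28.8,
        "step_in_ms_per_seq": 0.3,
        "global_batch_size": 96,
        "total_throughput_tokens_per_second": 3333.0,
        "device_bandwidth_GB_per_second": 700.0,
    },
    "sizes": {
        "model_size_in_gb": 14.0,
        "cache_size_in_gb": 8.0,
        "model_params_in_billions": 7.0,
    },
}

# print_results_for_analyze(results) would then print, for example:
#   PREFILL_BUCKET_SIZE_TO_MS = {1024: 12.345}
#   PREFILL_INSERT_BUCKET_SIZE_TO_MS = {1024: 13.9}
#   SYSTEM_TIME_PER_DECODE_TOKEN_MS = 0.3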
28 changes: 15 additions & 13 deletions MaxText/inference_microbenchmark_sweep.py
@@ -43,6 +43,7 @@ def main():
   """
   pyconfig.initialize(sys.argv)
   config = pyconfig.config
+  base_run_name = config.run_name

   with open(config.inference_metadata_file, encoding='utf-8') as json_file:
     inference_metadata = json.load(json_file)
@@ -54,6 +55,9 @@ def main():
   ar_key_axis_order_list = inference_metadata['ar_key_axis_order_list'].split(':')
   ar_value_axis_order_list = inference_metadata['ar_value_axis_order_list'].split(':')

+  start_key_value_axis_order_product_id = key_value_axis_order_product_id_list[0]
+  end_key_value_axis_order_product_id = key_value_axis_order_product_id_list[-1]
+
   results = []
   for (
       key_value_axis_order_product_id,
@@ -74,25 +78,17 @@ def main():
     print(f"ar_key_axis_order {ar_key_axis_order}")
     print(f"ar_value_axis_order {ar_value_axis_order}")

     # Manually update the config
     # Don't set key_value_axis_order_product_id; otherwise it will recompute
     # ar_key_axis_order and ar_value_axis_order
-    quant = 'bf16' if not config.quantization else config.quantization
-    run_name = (
-        f"{inference_metadata['accelerator']}-{config.model_name}-"
-        f"{quant}-{key_value_axis_order_product_id}-{prefill_key_axis_order}-"
-        f"{ar_key_axis_order}"
-    )
+    run_tag = (
+        f"{key_value_axis_order_product_id}-{prefill_key_axis_order.replace(',','')}-{ar_key_axis_order.replace(',','')}"
+    )
+    run_name = f"{base_run_name}/{run_tag}"

     tensorboard_dir = os.path.join(config.base_output_directory, run_name, "tensorboard", "")
-    checkpoint_dir = os.path.join(config.base_output_directory, run_name, "checkpoint", "")
-    metrics_dir = os.path.join(config.base_output_directory, run_name, "metrics", "")
     pyconfig._config.keys['prefill_key_axis_order'] = prefill_key_axis_order  # pylint: disable=protected-access
     pyconfig._config.keys['prefill_value_axis_order'] = prefill_value_axis_order  # pylint: disable=protected-access
     pyconfig._config.keys['ar_key_axis_order'] = ar_key_axis_order  # pylint: disable=protected-access
     pyconfig._config.keys['ar_value_axis_order'] = ar_value_axis_order  # pylint: disable=protected-access
     pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir  # pylint: disable=protected-access
-    pyconfig._config.keys['checkpoint_dir'] = checkpoint_dir  # pylint: disable=protected-access
-    pyconfig._config.keys['metrics_dir'] = metrics_dir  # pylint: disable=protected-access
     pyconfig._config.keys['run_name'] = run_name  # pylint: disable=protected-access
     max_utils.write_config_raw_keys_for_gcs(pyconfig._config.keys)  # pylint: disable=protected-access
Review thread (on the max_utils.write_config_raw_keys_for_gcs call above):

@morgandu (Collaborator), Jun 6, 2024: [comment body did not load]

@yeandy (Collaborator, Author), Jun 6, 2024:
My understanding is that write_config_raw_keys_for_gcs is called once during initialize (https://github.com/google/maxtext/blob/main/MaxText/pyconfig.py#L225) and writes the default values of the prefill and ar axis orders to GCS. So each time we loop over a different prefill/ar axis order, we need to explicitly call max_utils.write_config_raw_keys_for_gcs again to make sure we write the updated values to GCS.

Did you take this out in your code, and did it work for you?

@yeandy (Collaborator, Author):
@morgandu, should I do a test without max_utils.write_config_raw_keys_for_gcs?

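The behavior under discussion, as a minimal sketch (assumptions: pyconfig._config.keys is a mutable dict of raw config values, as the linked pyconfig.py line suggests, and write_config_raw_keys_for_gcs uploads whatever dict it is given; this is an illustration, not the PR's code):

import sys

import max_utils  # assumes running from the MaxText directory, as the scripts do
import pyconfig

pyconfig.initialize(sys.argv)  # internally writes the default raw keys
                               # (including default axis orders) to GCS once

for ar_key_axis_order in ("1,2,0,3", "2,1,0,3"):  # hypothetical sweep values
  pyconfig._config.keys["ar_key_axis_order"] = ar_key_axis_order  # pylint: disable=protected-access
  # Without re-calling this, GCS would keep the defaults written at
  # initialize() time rather than the per-iteration overrides.
  max_utils.write_config_raw_keys_for_gcs(pyconfig._config.keys)  # pylint: disable=protected-access
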
@@ -121,6 +117,8 @@ def main():
         "prefill_value_axis_order": f"{prefill_value_axis_order}",
         "ar_key_axis_order": f"{ar_key_axis_order}",
         "ar_value_axis_order": f"{ar_value_axis_order}",
+        "run_name": f"{run_name}",
+        "run_tag": f"{run_tag}",
         "config_json_string": json.dumps(
             pyconfig._config.keys,  # pylint: disable=protected-access
             default=lambda x: f"<<non-serializable: {type(x).__qualname__}>>"
@@ -135,10 +133,14 @@ def main():
       metrics = microbenchmark_results['flattened_results']
       metrics = {k.lower(): v for k, v in metrics.items()}
       dimensions_json['oom'] = 'False'
+      print(f"Completed run {key_value_axis_order_product_id} out of: "
+            f"{start_key_value_axis_order_product_id} to {end_key_value_axis_order_product_id}")
     except xla_extension.XlaRuntimeError:
       # OOM
       metrics = {}
       dimensions_json['oom'] = 'True'
+      print(f"Failed at run {key_value_axis_order_product_id} out of: "
+            f"{start_key_value_axis_order_product_id} to {end_key_value_axis_order_product_id}")

     final = {'metrics': metrics, 'dimensions': dimensions_json}
     print(f"Result: {final}")
@@ -151,4 +153,4 @@ def main():


 if __name__ == "__main__":
-  main()
\ No newline at end of file
+  main()
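
For reference, the metadata file that drives this sweep (config.inference_metadata_file) must provide colon-delimited lists, one entry per layout, given the .split(':') calls above. A hypothetical example, shown as the dict json.load() would return (field names come from the reads in main(); the accelerator value and axis orders are made-up examples):

# Hypothetical contents of the inference metadata file.
inference_metadata = {
    "accelerator": "v5e-8",
    "key_value_axis_order_product_id_list": "0:1:2",
    "prefill_key_axis_order_list": "1,2,0,3:2,1,0,3:0,2,1,3",
    "prefill_value_axis_order_list": "1,2,0,3:2,1,0,3:0,2,1,3",
    "ar_key_axis_order_list": "1,2,0,3:2,1,0,3:0,2,1,3",
    "ar_value_axis_order_list": "1,2,0,3:2,1,0,3:0,2,1,3",
}

# Each colon-delimited list splits into one entry per run, e.g.:
assert inference_metadata["ar_key_axis_order_list"].split(":")[1] == "2,1,0,3"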