diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index da3a5b584..4eedaf99b 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -323,6 +323,7 @@ inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
 inference_microbenchmark_stages: "prefill,generate"
 inference_microbenchmark_loop_iters: 10
 inference_microbenchmark_log_file_path: ""
+inference_metadata_file: ""  # path to a json file of sweep metadata, read by inference_microbenchmark_sweep.py
 
 # KV Cache layout control
 # Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index 3d52d2129..51bb85aaf 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -20,6 +20,9 @@
 import json
 import sys
 
+from collections.abc import MutableMapping
+from typing import Any, Dict, Optional
+
 from jetstream.engine import token_utils
 
 import max_utils
@@ -63,9 +66,9 @@ def prefill_benchmark(
       f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n"
   )
   result_dict = {
-      "prefill_time_in_ms": prefill_average_ms,
-      "prefill_total_tflops_per_device": prefill_tflops_per_device,
-      "prefill_tflops_per_sec_per_device": tflops_per_sec_per_device,
+      "time_in_ms": prefill_average_ms,
+      "total_tflops_per_device": prefill_tflops_per_device,
+      "tflops_per_sec_per_device": tflops_per_sec_per_device,
   }
   return result_dict
 
@@ -106,7 +109,7 @@ def prefill_insert_benchmark(
       f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n"
   )
   result_dict = {
-      "prefill_insert_time_in_ms": prefill_insert_average_ms
+      "time_in_ms": prefill_insert_average_ms
   }
   return result_dict, decode_state
 
@@ -147,11 +150,11 @@ def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_
   )
 
   result_dict = {
-      "ar_step_in_ms": ar_average_ms,
-      "ar_step_in_ms_per_seq": ar_average_ms / global_batch_size,
-      "ar_global_batch_size": global_batch_size,
-      "ar_total_throughput_tokens_per_second": total_throughput,
-      "ar_device_bandwidth_GB_per_second": bw_per_device,
+      "step_in_ms": ar_average_ms,
+      "step_in_ms_per_seq": ar_average_ms / global_batch_size,
+      "global_batch_size": global_batch_size,
+      "total_throughput_tokens_per_second": total_throughput,
+      "device_bandwidth_GB_per_second": bw_per_device,
   }
   return result_dict, decode_state
 
@@ -159,8 +162,8 @@ def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_
 def collate_results(config, results, model_size, cache_size, num_model_params, incl_config=False):
   """Adds model/cache size info and optionally config info to results."""
   results["sizes"] = {
-      "Model_size_in_GB": model_size / 1e9,
-      "cache_size_in_GB": cache_size / 1e9,
+      "model_size_in_gb": model_size / 1e9,
+      "cache_size_in_gb": cache_size / 1e9,
       "model_params_in_billions": num_model_params / 1e9,
   }
   if incl_config:
@@ -170,30 +173,45 @@ def collate_results(config, results, model_size, cache_size, num_model_params, i
   return results
 
 
-def write_results(results, filename):
+def flatten_dict(dictionary, prefix='', sep='_'):
+  """Flattens a nested dictionary, joining nested keys with `sep`."""
+  results = []
+  for k, v in dictionary.items():
+    new_key = str(prefix) + sep + str(k) if prefix else k
+    if isinstance(v, MutableMapping):
+      results.extend(flatten_dict(v, new_key, sep=sep).items())
+    else:
+      results.append((new_key, v))
+  return dict(results)
+
+
+def write_results(results, filename, flatten_microbenchmark_results):
+  """Writes the microbenchmark results to a json file."""
+  if flatten_microbenchmark_results:
+    results['flattened_results'] = flatten_dict(results)
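+    # The flattened copy maps joined key paths to leaf values,
+    # e.g. results["prefill"][64]["time_in_ms"] also appears as "prefill_64_time_in_ms".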
   if filename != "":
     with open(filename, "w", encoding="utf-8") as f:
       json.dump(results, f, indent=2)
+  return results
 
 
 def print_results_for_analyze(results):
   """Print results."""
   print("\nFor usage in analyze_sharegpt.py :")
-  if "Prefill" in results:
+  if "prefill" in results:
     prefill_bucket_size_to_ms = {}
-    for k, v in results["Prefill"].items():
-      prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
+    for k, v in results["prefill"].items():
+      prefill_bucket_size_to_ms[int(k)] = round(v["time_in_ms"], 3)
     print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")
 
-  if "Prefill_Insert" in results:
+  if "insert" in results:
     insert_bucket_size_to_ms = {}
-    for k, v in results["Prefill_Insert"].items():
-      insert_bucket_size_to_ms[int(k)] = round(v["prefill_insert_time_in_ms"], 3)
-    print(f"PREFILL_INSERT_BUCKET_SIZE_TO_MS = {insert_bucket_size_to_ms}")
+    for k, v in results["insert"].items():
+      insert_bucket_size_to_ms[int(k)] = round(v["time_in_ms"], 3)
+    print(f"INSERT_BUCKET_SIZE_TO_MS = {insert_bucket_size_to_ms}")
 
-  if "AutoRegressive" in results:
-    print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")
+  if "autoregressive" in results:
+    print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['autoregressive']['step_in_ms_per_seq']}")
 
 
 def summarize_prefill_result(engine, params, tokens, true_length):
@@ -209,16 +227,16 @@ def summarize_prefill_result(engine, params, tokens, true_length):
   )
   del prefill_result
   return {
-      "num_prefill_logits_params": num_prefill_logits_params,
-      "total_prefill_logits_size": total_prefill_logits_size,
-      "avg_prefill_logits_param_size": avg_prefill_logits_param_size,
-      "num_prefill_cache_params": num_prefill_cache_params,
-      "total_prefill_cache_size": total_prefill_cache_size,
-      "avg_prefill_cache_param_size": avg_prefill_cache_param_size,
+      "num_logits_params": num_prefill_logits_params,
+      "total_logits_size": total_prefill_logits_size,
+      "avg_logits_param_size": avg_prefill_logits_param_size,
+      "num_cache_params": num_prefill_cache_params,
+      "total_cache_size": total_prefill_cache_size,
+      "avg_cache_param_size": avg_prefill_cache_param_size,
   }
 
 
-def main(config):
+def main(config, inference_metadata: Optional[Dict[str, Any]] = None):
   engine = maxengine.MaxEngine(config)
   params = engine.load_params()
   prefill_lengths = [int(l) for l in config.inference_microbenchmark_prefill_lengths.split(",")]
@@ -236,9 +254,9 @@ def main(config):
 
   benchmark_results = {}
   if "prefill" in stages_to_benchmark:
-    benchmark_results["Prefill_Result"] = {}
-    benchmark_results["Prefill"] = {}
-    benchmark_results["Prefill_Insert"] = {}
+    benchmark_results["prefill-result-sizes"] = {}
+    benchmark_results["prefill"] = {}
+    benchmark_results["insert"] = {}
 
     prefill_tokens = {}
     prefill_true_lengths = {}
@@ -246,12 +264,12 @@
       prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad(
         text, vocab, is_bos=True, prefill_lengths=[prefill_length]
       )
-      benchmark_results["Prefill_Result"]["prefill_length"] = summarize_prefill_result(
+      benchmark_results["prefill-result-sizes"][prefill_length] = summarize_prefill_result(
        engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
      )
 
    for prefill_length in prefill_lengths:
-      benchmark_results["Prefill"][prefill_length] = prefill_benchmark(
+      benchmark_results["prefill"][prefill_length] = prefill_benchmark(
        config,
        engine,
        params,
        prefill_tokens[prefill_length],
        prefill_true_lengths[prefill_length],
        benchmark_loop_iters
      )
-      benchmark_results["Prefill_Insert"][prefill_length], decode_state = prefill_insert_benchmark(
+      prefill_insert_time, decode_state = prefill_insert_benchmark(
        config,
        engine,
        decode_state,
@@ -271,14 +289,29 @@
        prefill_tokens[prefill_length],
        prefill_true_lengths[prefill_length],
        benchmark_loop_iters
      )
+      benchmark_results["insert"][prefill_length] = {}
+      benchmark_results["insert"][prefill_length]["time_in_ms"] = (
+          prefill_insert_time["time_in_ms"] - benchmark_results["prefill"][prefill_length]["time_in_ms"]
+      )
 
   if "generate" in stages_to_benchmark:
-    benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
+    benchmark_results["autoregressive"], decode_state = ar_benchmark(
       config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)
 
   results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
-  write_results(results, filename=config.inference_microbenchmark_log_file_path)
   print_results_for_analyze(results)
+  if inference_metadata:
+    flatten_microbenchmark_results = pyconfig.string_to_bool(
+        inference_metadata.get('flatten_microbenchmark_results', 'false')
+    )
+  else:
+    flatten_microbenchmark_results = False
+  results = write_results(
+      results,
+      filename=config.inference_microbenchmark_log_file_path,
+      flatten_microbenchmark_results=flatten_microbenchmark_results
+  )
+  return results
 
 
 if __name__ == "__main__":
diff --git a/MaxText/inference_microbenchmark_sweep.py b/MaxText/inference_microbenchmark_sweep.py
new file mode 100644
index 000000000..8f7ddda5b
--- /dev/null
+++ b/MaxText/inference_microbenchmark_sweep.py
@@ -0,0 +1,143 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Sweep across inference microbenchmarks."""
+
+import os
+import sys
+import json
+import jsonlines
+import inference_microbenchmark
+import pyconfig
+from jax._src.lib import xla_extension
+
+
+def main():
+  """
+  User needs to set the config's inference_metadata_file, which is a path to a json file.
+
+  This json should contain the following keys:
+    - two_axis_order_product_id_list: ':'-delimited string of two_axis_order_product_ids
+    - prefill_cache_axis_order_list: ':'-delimited string of prefill_cache_axis_orders
+    - ar_cache_axis_order_list: ':'-delimited string of ar_cache_axis_orders
+    - accelerator: name of the accelerator
+    - flatten_microbenchmark_results: whether or not to flatten the results; should be true
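+
+  Example contents of the json file (values here are purely illustrative):
+    {
+      "two_axis_order_product_id_list": "0:1",
+      "prefill_cache_axis_order_list": "0,2,1,3:0,1,2,3",
+      "ar_cache_axis_order_list": "0,2,1,3:0,1,2,3",
+      "accelerator": "v5e-8",
+      "flatten_microbenchmark_results": "true"
+    }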
+  """
+  pyconfig.initialize(sys.argv)
+  config = pyconfig.config
+  base_run_name = config.run_name
+
+  with open(config.inference_metadata_file, encoding='utf-8') as json_file:
+    inference_metadata = json.load(json_file)
+    print(f"inference_metadata: {inference_metadata}")
+
+  two_axis_order_product_id_list = inference_metadata['two_axis_order_product_id_list'].split(':')
+  prefill_cache_axis_order_list = inference_metadata['prefill_cache_axis_order_list'].split(':')
+  ar_cache_axis_order_list = inference_metadata['ar_cache_axis_order_list'].split(':')
+
+  start_two_axis_order_product_id = two_axis_order_product_id_list[0]
+  end_two_axis_order_product_id = two_axis_order_product_id_list[-1]
+
+  results = []
+  for (
+      two_axis_order_product_id,
+      prefill_cache_axis_order,
+      ar_cache_axis_order,
+  ) in zip(
+      two_axis_order_product_id_list,
+      prefill_cache_axis_order_list,
+      ar_cache_axis_order_list,
+  ):
+    print(f"two_axis_order_product_id {two_axis_order_product_id}")
+    print(f"prefill_cache_axis_order {prefill_cache_axis_order}")
+    print(f"ar_cache_axis_order {ar_cache_axis_order}")
+
+    run_tag = (
+        f"{two_axis_order_product_id}-{prefill_cache_axis_order.replace(',','')}-{ar_cache_axis_order.replace(',','')}"
+    )
+    run_name = f"{base_run_name}/{run_tag}"
+
+    tensorboard_dir = os.path.join(config.base_output_directory, run_name, "tensorboard", "")
+    pyconfig._config.keys['prefill_cache_axis_order'] = prefill_cache_axis_order  # pylint: disable=protected-access
+    pyconfig._config.keys['ar_cache_axis_order'] = ar_cache_axis_order  # pylint: disable=protected-access
+    pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir  # pylint: disable=protected-access
+    pyconfig._config.keys['run_name'] = run_name  # pylint: disable=protected-access
+
+    # Prepare metadata (dimensions) json for XLML
+    dimensions_json = {
+        "base_output_directory": config.base_output_directory,
+        "model_name": config.model_name,
+        "tokenizer": config.tokenizer_path,
+        "weight_dtype": config.weight_dtype,
+        "inference_microbenchmark_prefill_lengths": f"{config.inference_microbenchmark_prefill_lengths}",
+        "inference_microbenchmark_stages": config.inference_microbenchmark_stages,
+        "inference_microbenchmark_loop_iters": f"{config.inference_microbenchmark_loop_iters}",
+        "max_prefill_predict_length": f"{config.max_prefill_predict_length}",
+        "max_target_length": f"{config.max_target_length}",
+        "per_device_batch_size": f"{config.per_device_batch_size}",
+        "ici_fsdp_parallelism": f"{config.ici_fsdp_parallelism}",
+        "ici_autoregressive_parallelism": f"{config.ici_autoregressive_parallelism}",
+        "ici_tensor_parallelism": f"{config.ici_tensor_parallelism}",
+        "profiler": f"{config.profiler}",
+        "scan_layers": f"{config.scan_layers}",
+        "quantization": config.quantization,
+        "quantize_kvcache": f"{config.quantize_kvcache}",
+        "attention": config.attention,
+        "two_axis_order_product_id": f"{two_axis_order_product_id}",
+        "prefill_cache_axis_order": f"{prefill_cache_axis_order}",
+        "ar_cache_axis_order": f"{ar_cache_axis_order}",
+        "compute_axis_order": f"{config.compute_axis_order}",
+        "reshape_q": f"{config.reshape_q}",
+        "kv_quant_axis": f"{config.kv_quant_axis}",
+        "run_name": f"{run_name}",
+        "run_tag": f"{run_tag}",
+        "config_json_string": json.dumps(
+            pyconfig._config.keys,  # pylint: disable=protected-access
+            default=lambda x: f"<<non-serializable: {type(x).__qualname__}>>"
+        )
+    }
+    dimensions_json = {
+        **dimensions_json,
+        **inference_metadata,
+    }
+    try:
+      microbenchmark_results = inference_microbenchmark.main(config, inference_metadata=inference_metadata)
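+      # main() returns the collated benchmark results; 'flattened_results' holds the
+      # flat {metric_name: value} view added by write_results().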
+      metrics = microbenchmark_results['flattened_results']
+      metrics = {k.lower(): v for k, v in metrics.items()}
+      dimensions_json['oom'] = 'False'
+      print(f"Completed run {two_axis_order_product_id} out of: "
+            f"{start_two_axis_order_product_id} to {end_two_axis_order_product_id}")
+    except xla_extension.XlaRuntimeError:
+      # OOM
+      metrics = {}
+      dimensions_json['oom'] = 'True'
+      print(f"Failed at run {two_axis_order_product_id} out of: "
+            f"{start_two_axis_order_product_id} to {end_two_axis_order_product_id}")
+
+    final = {'metrics': metrics, 'dimensions': dimensions_json}
+    print(f"Result: {final}")
+    results.append(final)
+
+  print(f"All results {results}")
+  path = 'inference_microbenchmark_sweep_results.jsonl'
+  with jsonlines.open(path, mode="w") as writer:
+    writer.write_all(results)
+
+
+if __name__ == "__main__":
+  main()