From 6c5e5ef236402ab829c2bb83d5f9ef29af3720f7 Mon Sep 17 00:00:00 2001
From: gkumbhat
Date: Fri, 25 Jul 2025 14:48:41 -0500
Subject: [PATCH 1/4] :sparkles: Add compilation time test

Signed-off-by: gkumbhat
---
 tests/testing/test_compilation.py | 138 ++++++++++++++++++++++++++++++
 1 file changed, 138 insertions(+)
 create mode 100644 tests/testing/test_compilation.py

diff --git a/tests/testing/test_compilation.py b/tests/testing/test_compilation.py
new file mode 100644
index 00000000..5535e562
--- /dev/null
+++ b/tests/testing/test_compilation.py
@@ -0,0 +1,138 @@
+"""This module contains test related to compilation operation"""
+
+# Standard
+import itertools
+import os
+import pytest
+import time
+
+# Third Party
+from torch import distributed as dist
+import torch
+
+# First Party
+from fms.models import get_model
+from fms.utils import generation, tokenizers
+from fms.utils.generation import pad_input_ids
+
+# Local
+from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests, warmup_model
+from aiu_fms_testing_utils.utils.aiu_setup import dprint
+
+GRANITE_3p3_8B_INSTRUCT = "ibm-granite/granite-3.3-8b-instruct"
+SHARE_GPT_DATASET_PATH = os.environ.get(
+    "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
+)
+
+ATTN_NAME = "spyre_paged_attn"
+
+compile_dynamic_sendnn = True
+
+common_model_paths = [GRANITE_3p3_8B_INSTRUCT]
+common_batch_sizes = [1]
+common_seq_lengths = [256]
+common_shape_types = ["dynamic"]
+common_max_new_tokens = [128]
+common_expected_comp_time = [10] # In minutes
+
+if compile_dynamic_sendnn:
+    os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
+        (((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64
+    )
+    os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2))
+
+common_shapes = list(
+    zip(
+        common_model_paths,
+        common_shape_types,
+        common_batch_sizes,
+        common_seq_lengths,
+        common_max_new_tokens,
+        common_expected_comp_time,
+    )
+)
+
+
+# TODO: This is copied from test_decoders.py would be good to consolidate
+def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
+    prompts_and_sizes = sample_sharegpt_requests(
+        SHARE_GPT_DATASET_PATH,
+        batch_size,
+        tokenizer,
+        int(seq_length / 2),
+        seq_length,
+        seed,
+    )
+    prompt_list = []
+    for prompt, _ in prompts_and_sizes:
+        prompt_list.append(ids_for_prompt(prompt, tokenizer))
+
+    input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
+    return input_ids, extra_kwargs
+
+
+@pytest.fixture(autouse=True)
+def reset_compiler():
+    yield # run the test
+    if not compile_dynamic_sendnn:
+        torch.compiler.reset()
+        torch._dynamo.reset()
+        os.environ.pop("COMPILATION_MODE", None)
+
+
+@pytest.mark.parametrize(
+    "model_path,shape_type,batch_size,seq_length,max_new_tokens,expected_comp_time", common_shapes
+)
+def test_compilation_time(model_path, shape_type, batch_size, seq_length, max_new_tokens, expected_comp_time):
+    """Test to validate time taken for model compilation."""
+    torch.manual_seed(42)
+    torch.set_default_dtype(torch.float16)
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+
+    dprint(
+        f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}"
+    )
+
+    if os.path.exists(model_path):
+        model_path_kwargs = {"model_path": model_path}
+    else:
+        model_path_kwargs = {"variant": model_path}
+
+    tokenizer = tokenizers.get_tokenizer(model_path)
+
+    # prepare the AIU model
+    model = get_model(
+        architecture="hf_pretrained",
+        device_type="cpu",
+        data_type= torch.float16,
+        fused_weights=False,
+        **model_path_kwargs,
+    )
+
+    model.eval()
+    torch.set_grad_enabled(False)
+
+    # prepare input_ids
+    input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+    extra_kwargs["attn_name"] = ATTN_NAME
+
+    start_time = time.perf_counter()
+    if shape_type == "dynamic":
+        compile_dynamic_sendnn = True
+    else:
+        compile_dynamic_sendnn = False
+
+    model.compile(
+        backend="sendnn", options={"sendnn.dynamic": compile_dynamic_sendnn}
+    )
+    warmup_model(
+        model,
+        input_ids,
+        max_new_tokens,
+        compile_dynamic_sendnn,
+        use_cache=False,
+        **extra_kwargs
+    )
+    end_time = time.perf_counter()
+
+    assert (end_time - start_time) < expected_comp_time * 60
\ No newline at end of file

From e7c18f851994f36a31b9b04ad447459508bba6c4 Mon Sep 17 00:00:00 2001
From: gkumbhat
Date: Mon, 28 Jul 2025 21:47:55 -0500
Subject: [PATCH 2/4] :art: Fix formatting

Signed-off-by: gkumbhat
---
 aiu_fms_testing_utils/utils/__init__.py | 13 +++++
 tests/testing/test_compilation.py       | 76 ++++++++++++++-----------
 2 files changed, 57 insertions(+), 32 deletions(-)

diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index 9a14d083..e5721b68 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -67,6 +67,19 @@ def warmup_model(
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
 
+def get_env_to_int_list(env_var_name, default):
+    """Utility function to convert list of strings passed as given environment variable to
+    list of integers
+    """
+    env_var_string = os.environ.get(env_var_name, default=default)
+    if not env_var_string:
+        return []
+    if isinstance(env_var_string, list):
+        return env_var_string
+
+    return [int(v) for v in env_var_string.split(",") if not isinstance(v, int)]
+
+
 def ids_for_prompt(prompt, tokenizer):
     tokens = tokenizer.tokenize(prompt)
     ids = tokenizer.convert_tokens_to_ids(tokens)
diff --git a/tests/testing/test_compilation.py b/tests/testing/test_compilation.py
index 5535e562..9c4a3eae 100644
--- a/tests/testing/test_compilation.py
+++ b/tests/testing/test_compilation.py
@@ -1,22 +1,25 @@
 """This module contains test related to compilation operation"""
 
 # Standard
-import itertools
 import os
 import pytest
 import time
 
 # Third Party
-from torch import distributed as dist
 import torch
 
 # First Party
 from fms.models import get_model
-from fms.utils import generation, tokenizers
+from fms.utils import tokenizers
 from fms.utils.generation import pad_input_ids
 
 # Local
-from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests, warmup_model
+from aiu_fms_testing_utils.utils import (
+    ids_for_prompt,
+    get_env_to_int_list,
+    sample_sharegpt_requests,
+    warmup_model,
+)
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 
 GRANITE_3p3_8B_INSTRUCT = "ibm-granite/granite-3.3-8b-instruct"
@@ -26,25 +29,39 @@
 
 ATTN_NAME = "spyre_paged_attn"
 
-compile_dynamic_sendnn = True
+COMPILE_DYNAMIC_SHAPE = True
+
+
+common_model_paths = get_env_to_int_list("COMMON_MODEL_NAME", [GRANITE_3p3_8B_INSTRUCT])
+common_batch_sizes = get_env_to_int_list("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1])
+common_seq_lengths = get_env_to_int_list("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64])
+common_max_new_tokens = get_env_to_int_list(
+    "FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS", [64]
+)
+common_expected_comp_time = get_env_to_int_list(
+    "COMMON_COMPILATION_EXPECTED_TIME", [10]
+) # In minutes
 
-common_model_paths = [GRANITE_3p3_8B_INSTRUCT]
-common_batch_sizes = [1]
-common_seq_lengths = [256]
-common_shape_types = ["dynamic"]
-common_max_new_tokens = [128]
-common_expected_comp_time = [10] # In minutes
+COMMON_SHAPE_TYPE = "dynamic"
 
-if compile_dynamic_sendnn:
+if COMPILE_DYNAMIC_SHAPE:
+    import bisect
+
+    # the compiler supports certain max context lengths (VLLM_DT_MAX_CONTEXT_LEN)
+    # this will ensure that we select smallest supported VLLM_DT_MAX_CONTEXT_LEN that fits the largest possible context (prompt size + max_new_tokens)
+    __largest_context = max(common_seq_lengths) + max(common_max_new_tokens)
+    __supported_context_lengths = [256, 512, 1024, 2048, 4096, 8192]
     os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
-        (((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64
+        __supported_context_lengths[
+            bisect.bisect_left(__supported_context_lengths, __largest_context)
+        ]
     )
     os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2))
 
-common_shapes = list(
+COMMON_SHAPES = list(
     zip(
         common_model_paths,
-        common_shape_types,
         common_batch_sizes,
         common_seq_lengths,
         common_max_new_tokens,
         common_expected_comp_time,
     )
 )
@@ -59,7 +76,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
         SHARE_GPT_DATASET_PATH,
         batch_size,
         tokenizer,
-        int(seq_length / 2),
+        seq_length // 2,
         seq_length,
         seed,
     )
@@ -74,16 +91,18 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
 @pytest.fixture(autouse=True)
 def reset_compiler():
     yield # run the test
-    if not compile_dynamic_sendnn:
+    if not COMPILE_DYNAMIC_SHAPE:
         torch.compiler.reset()
         torch._dynamo.reset()
         os.environ.pop("COMPILATION_MODE", None)
 
 
 @pytest.mark.parametrize(
-    "model_path,shape_type,batch_size,seq_length,max_new_tokens,expected_comp_time", common_shapes
+    "model_path,batch_size,seq_length,max_new_tokens,expected_comp_time", COMMON_SHAPES
 )
-def test_compilation_time(model_path, shape_type, batch_size, seq_length, max_new_tokens, expected_comp_time):
+def test_compilation_time(
+    model_path, batch_size, seq_length, max_new_tokens, expected_comp_time
+):
     """Test to validate time taken for model compilation."""
     torch.manual_seed(42)
     torch.set_default_dtype(torch.float16)
@@ -104,7 +123,7 @@ def test_compilation_time(model_path, shape_type, batch_size, seq_length, max_ne
     model = get_model(
         architecture="hf_pretrained",
         device_type="cpu",
-        data_type= torch.float16,
+        data_type=torch.float16,
         fused_weights=False,
         **model_path_kwargs,
     )
@@ -117,21 +136,14 @@ def test_compilation_time(model_path, shape_type, batch_size, seq_length, max_ne
     extra_kwargs["attn_name"] = ATTN_NAME
 
     start_time = time.perf_counter()
-    if shape_type == "dynamic":
-        compile_dynamic_sendnn = True
+    if COMMON_SHAPE_TYPE == "dynamic":
+        COMPILE_DYNAMIC_SHAPE = True
     else:
-        compile_dynamic_sendnn = False
+        COMPILE_DYNAMIC_SHAPE = False
 
-    model.compile(
-        backend="sendnn", options={"sendnn.dynamic": compile_dynamic_sendnn}
-    )
+    model.compile(backend="sendnn", options={"sendnn.dynamic": COMPILE_DYNAMIC_SHAPE})
     warmup_model(
-        model,
-        input_ids,
-        max_new_tokens,
-        compile_dynamic_sendnn,
-        use_cache=False,
-        **extra_kwargs
+        model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SHAPE, **extra_kwargs
     )
     end_time = time.perf_counter()
 

From d680f6749e4d1b13fb88bb72632af340cb445b10 Mon Sep 17 00:00:00 2001
From: gkumbhat
Date: Wed, 30 Jul 2025 15:01:56 -0500
Subject: [PATCH 3/4] :memo: Add doc string for get_env_to_int_list function

Signed-off-by: gkumbhat
---
 aiu_fms_testing_utils/utils/__init__.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index e5721b68..c078d7ba 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -1,5 +1,5 @@
 # Standard
-from typing import Optional, List, Tuple
+from typing import Any, Optional, List, Tuple
 import json
 import os
 import random
@@ -67,9 +67,17 @@ def warmup_model(
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
 
-def get_env_to_int_list(env_var_name, default):
+def get_env_to_datatype_list(env_var_name: str, default: Any, data_type = int):
     """Utility function to convert list of strings passed as given environment variable to
-    list of integers
+    list of provided data_type (default = int)
+    Args:
+        env_var_name (str): The name of the environment variable to retrieve.
+        default (list or any): The default value to return if the environment variable is not set.
+            If a list is provided, it will be returned as is if the environment variable is not set.
+        data_type (type, optional): The data type to convert the string values to. Defaults to int.
+
+    Returns:
+        list: A list of integers or the default value if the environment variable is not set or is an empty string.
     """
     env_var_string = os.environ.get(env_var_name, default=default)
     if not env_var_string:
@@ -77,7 +85,7 @@ def get_env_to_int_list(env_var_name, default):
     if isinstance(env_var_string, list):
         return env_var_string
 
-    return [int(v) for v in env_var_string.split(",") if not isinstance(v, int)]
+    return [data_type(v) for v in env_var_string.split(",") if not isinstance(v, data_type)]
 
 
 def ids_for_prompt(prompt, tokenizer):

From 8640dc2e1e4578ae171a7e7eb67af7dc8ce9fd6d Mon Sep 17 00:00:00 2001
From: gkumbhat
Date: Fri, 1 Aug 2025 16:36:45 -0500
Subject: [PATCH 4/4] :construction::label: Work in progress add util function and types

Signed-off-by: gkumbhat
---
 aiu_fms_testing_utils/utils/__init__.py |   4 +-
 aiu_fms_testing_utils/utils/paged.py    |   7 +-
 tests/testing/test_compilation.py       | 158 +++++++++++++++++++-----
 3 files changed, 136 insertions(+), 33 deletions(-)

diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index c078d7ba..bbc56765 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -67,7 +67,7 @@ def warmup_model(
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
 
-def get_env_to_datatype_list(env_var_name: str, default: Any, data_type = int):
+def get_env_to_datatype_list(env_var_name: str, default: Any, data_type = int) -> List[Any]:
     """Utility function to convert list of strings passed as given environment variable to
     list of provided data_type (default = int)
     Args:
@@ -77,7 +77,7 @@ def get_env_to_datatype_list(env_var_name: str, default: Any, data_type = int):
         data_type (type, optional): The data type to convert the string values to. Defaults to int.
 
     Returns:
-        list: A list of integers or the default value if the environment variable is not set or is an empty string.
+        list: A list of given data_type or the default value if the environment variable is not set or is an empty string.
     """
     env_var_string = os.environ.get(env_var_name, default=default)
     if not env_var_string:
diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py
index 05d42e9f..81b2bf64 100644
--- a/aiu_fms_testing_utils/utils/paged.py
+++ b/aiu_fms_testing_utils/utils/paged.py
@@ -144,7 +144,7 @@ def generate(
         kvheads = kvheads // tensor_parallel_size if kvheads > 1 else kvheads
         head_size = model.config.emb_dim // nheads
 
-        if "fp8" in kwargs["attn_name"]:
+        if "fp8" in kwargs.get("attn_name", ""):
             from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor
 
             kwargs["past_key_value_states"] = [
@@ -262,8 +262,9 @@ def generate(
 
             # This view will result in a discontiguous tensor (creates a new graph during compile)
             # For this reason, we must explicitly make contiguous
+            # kwargs["mask"][seq_i][:, -current_tkv:, -current_tkv:]
             mask_i = (
-                kwargs["mask"][seq_i][:, -current_tkv:, -current_tkv:]
+                kwargs["mask"][seq_i][:, -current_tkv:]
                 .unsqueeze(0)
                 .contiguous()
             )
@@ -364,7 +365,7 @@ def generate(
             torch._dynamo.mark_dynamic(kwargs["position_ids"], 0)
             torch._dynamo.mark_dynamic(kwargs["current_tkv_mask"], 0)
             torch._dynamo.mark_dynamic(kwargs["left_padded_prompt_mask"], 0)
-            if "fp8" in kwargs["attn_name"]:
+            if "fp8" in kwargs.get("attn_name", ""):
                 for k_cache, v_cache in kwargs["past_key_value_states"]:
                     torch._dynamo.mark_dynamic(k_cache._scale, 0)
                     torch._dynamo.mark_dynamic(v_cache._scale, 0)
diff --git a/tests/testing/test_compilation.py b/tests/testing/test_compilation.py
index 9c4a3eae..72b12558 100644
--- a/tests/testing/test_compilation.py
+++ b/tests/testing/test_compilation.py
@@ -6,6 +6,7 @@
 import time
 
 # Third Party
+from transformers.tokenization_utils_base import BatchEncoding
 import torch
 
 # First Party
@@ -16,7 +17,7 @@
 # Local
 from aiu_fms_testing_utils.utils import (
     ids_for_prompt,
-    get_env_to_int_list,
+    get_env_to_datatype_list,
     sample_sharegpt_requests,
     warmup_model,
 )
@@ -27,37 +28,33 @@
     "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
 )
 
-ATTN_NAME = "spyre_paged_attn"
+ATTN_TYPE = os.environ.get("FMS_TEST_SHAPES_ATTN_TYPE", "paged")
+attention_map = {
+    "sdpa": "sdpa_causal",
+    "paged": "spyre_paged_attn",
+    "math_fp8": "math_fp8",
+    "paged_fp8": "spyre_paged_attn_fp8",
+}
+ATTN_NAME = attention_map[ATTN_TYPE]
 
-COMPILE_DYNAMIC_SHAPE = True
+COMPILE_DYNAMIC_SHAPE = os.environ.get("AIU_COMPILE_DYNAMIC_SHAPE", True)
 
 
-common_model_paths = get_env_to_int_list("COMMON_MODEL_NAME", [GRANITE_3p3_8B_INSTRUCT])
-common_batch_sizes = get_env_to_int_list("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1])
-common_seq_lengths = get_env_to_int_list("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64])
-common_max_new_tokens = get_env_to_int_list(
+common_model_paths = get_env_to_datatype_list("COMMON_MODEL_NAME", [GRANITE_3p3_8B_INSTRUCT])
+common_batch_sizes = get_env_to_datatype_list("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1])
+common_seq_lengths = get_env_to_datatype_list("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64])
+common_max_new_tokens = get_env_to_datatype_list(
     "FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS", [64]
 )
-common_expected_comp_time = get_env_to_int_list(
+common_expected_comp_time = get_env_to_datatype_list(
    "COMMON_COMPILATION_EXPECTED_TIME", [10]
 ) # In minutes
 
-COMMON_SHAPE_TYPE = "dynamic"
-
-if COMPILE_DYNAMIC_SHAPE:
-    import bisect
-
-    # the compiler supports certain max context lengths (VLLM_DT_MAX_CONTEXT_LEN)
-    # this will ensure that we select smallest supported VLLM_DT_MAX_CONTEXT_LEN that fits the largest possible context (prompt size + max_new_tokens)
-    __largest_context = max(common_seq_lengths) + max(common_max_new_tokens)
-    __supported_context_lengths = [256, 512, 1024, 2048, 4096, 8192]
-    os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
-        __supported_context_lengths[
-            bisect.bisect_left(__supported_context_lengths, __largest_context)
-        ]
-    )
-    os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2))
+if COMPILE_DYNAMIC_SHAPE:
+    COMMON_SHAPE_TYPE = "dynamic"
+else:
+    COMMON_SHAPE_TYPE = "static"
 
 COMMON_SHAPES = list(
     zip(
@@ -70,7 +67,103 @@
 )
 
 
-# TODO: This is copied from test_decoders.py would be good to consolidate
+def __set_context_length(seq_len: int, max_new_tokens: int, batch_size: int) -> None:
+    """
+    This function sets the environment variables for maximum context length and batch size.
+
+    It calculates the largest context by adding the sequence length and the maximum number of new tokens.
+
+    Args:
+        seq_len (int): The length of the input sequence.
+        max_new_tokens (int): The maximum number of new tokens to generate.
+        batch_size (int): The batch size for processing.
+
+    This function sets the environment variables:
+    - VLLM_DT_MAX_CONTEXT_LEN: The selected maximum context length.
+    - VLLM_DT_MAX_BATCH_SIZE: The maximum batch size, with a minimum value of 2.
+    """
+    largest_context = seq_len + max_new_tokens
+    os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(largest_context)
+    os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(batch_size, 2))
+
+def __get_dummy_inputs(batch_size, seq_length, tokenizer):
+    """
+    This function creates dummy input tensors for a given sequence length.
+    It uses the tokenizer to generate valid token IDs, excluding special
+    tokens (beginning-of-sequence and end-of-sequence).
+
+    Args:
+        batch_size (int): The number of sequences in a batch.
+        seq_length (int): The length of each sequence.
+        tokenizer (Tokenizer): The tokenizer object used for tokenization.
+
+    Returns:
+        Tuple(input_ids, attention_masks)
+        - input_ids (torch.Tensor): A tensor of shape (batch_size, seq_length)
+            containing randomly sampled valid token IDs.
+        - attention_masks (torch.Tensor): A tensor of shape (batch_size, seq_length)
+            filled with ones, indicating that all tokens are attended to.
+    """
+    vocab_size = tokenizer.tokenizer.vocab_size
+    bos_token_id = tokenizer.bos_token_id
+    eos_token_id = tokenizer.eos_token_id
+    special_token_ids = [bos_token_id, eos_token_id]
+    # breakpoint()
+    valid_token_ids = [
+        i for i in range(1, vocab_size) if i not in set(special_token_ids)
+    ]
+    valid_token_ids_tensor = torch.tensor(valid_token_ids, dtype=torch.long)
+
+    # Sample from the valid token ids
+    input_ids = valid_token_ids_tensor[torch.randint(
+        0, len(valid_token_ids_tensor), (batch_size, seq_length))]
+
+    attention_masks = torch.ones((batch_size, seq_length))
+
+    position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
+
+    # breakpoint()
+
+    batch = BatchEncoding({
+        "input_ids": input_ids,
+        "mask": attention_masks,
+        "position_ids": position_ids,
+    })
+    return batch
+
+
+
+# TODO: Move this function outside for re-usability in other places
+def __generate(model, input_ids, max_new_tokens, **kwargs):
+    import torch_sendnn
+
+    attention_specific_kwargs = {}
+    attn_name = kwargs.get("attn_name", "sdpa")
+
+    extra_kwargs = kwargs.get("extra_kwargs", {})
+
+    if "paged" in attn_name:
+        from aiu_fms_testing_utils.utils.paged import generate, adjust_inputs_to_batch
+        input_ids, _extra_kwargs = adjust_inputs_to_batch(input_ids=input_ids, **extra_kwargs)
+        extra_kwargs = {**_extra_kwargs, "attn_name": attn_name}
+    else:
+        # TODO: Add a unified generation dependent on attn_type
+        from fms.utils.generation import generate
+
+        attention_specific_kwargs["contiguous_cache"] = True
+        attention_specific_kwargs["max_seq_len"] = input_ids.shape[1] + max_new_tokens
+        extra_kwargs["only_last_token"] = True
+
+    return generate(
+        model,
+        input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=False,
+        extra_kwargs=extra_kwargs,
+        **attention_specific_kwargs,
+    )
+
+
 def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     prompts_and_sizes = sample_sharegpt_requests(
         SHARE_GPT_DATASET_PATH,
@@ -107,9 +200,10 @@ def test_compilation_time(
     torch.manual_seed(42)
     torch.set_default_dtype(torch.float16)
     os.environ["COMPILATION_MODE"] = "offline_decoder"
+    __set_context_length(seq_length, max_new_tokens, batch_size)
 
     dprint(
-        f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}"
+        f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, attn_type={ATTN_TYPE}",
     )
 
     if os.path.exists(model_path):
@@ -131,9 +225,10 @@ def test_compilation_time(
     model.eval()
     torch.set_grad_enabled(False)
 
-    # prepare input_ids
-    input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
-    extra_kwargs["attn_name"] = ATTN_NAME
+    # prepare batch input
+    batch = __get_dummy_inputs(batch_size, seq_length, tokenizer)
+    # inputs, args = __prepare_inputs(batch_size, seq_length, tokenizer)
+    # breakpoint()
 
     start_time = time.perf_counter()
     if COMMON_SHAPE_TYPE == "dynamic":
@@ -143,8 +238,15 @@ def test_compilation_time(
 
     model.compile(backend="sendnn", options={"sendnn.dynamic": COMPILE_DYNAMIC_SHAPE})
     warmup_model(
-        model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SHAPE, **extra_kwargs
+        model, batch["input_ids"], max_new_tokens, COMPILE_DYNAMIC_SHAPE, attn_name=ATTN_NAME, mask=batch["mask"], position_ids=batch["position_ids"]
     )
+    extra_kwargs = {
+        "position_ids": batch["position_ids"],
+        "mask": batch["mask"],
+        "attn_name": ATTN_NAME
+    }
+    __generate(model, batch["input_ids"], max_new_tokens=max_new_tokens, use_cache=True, extra_kwargs=extra_kwargs)
+
     end_time = time.perf_counter()
 
     assert (end_time - start_time) < expected_comp_time * 60
\ No newline at end of file
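
A minimal usage sketch (not part of the patch series) of the get_env_to_datatype_list helper introduced in patches 2-4, assuming these patches are applied and aiu_fms_testing_utils is importable; the environment variable values shown are illustrative only:

# Sketch: how the comma-separated environment variables consumed by the test are parsed.
import os

from aiu_fms_testing_utils.utils import get_env_to_datatype_list

# Variable unset: the default list is returned as-is.
os.environ.pop("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", None)
assert get_env_to_datatype_list("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1]) == [1]

# Comma-separated string: each element is converted with data_type (int by default),
# which is how FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS and _MAX_NEW_TOKENS are consumed.
os.environ["FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS"] = "64,256"
assert get_env_to_datatype_list("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64]) == [64, 256]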