diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 642cdb8773..14fdad79a2 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -358,7 +358,7 @@ def mod_transform_before_build( ) # pylint: disable=not-callable if "num_attention_heads" in config and "hidden_size" in config: - if args.max_seq_len: + if args.max_seq_len != -1: mod = fuse_split_rotary_embedding(mod, config["num_attention_heads"], config["hidden_size"], args.max_seq_len) else: mod = fuse_split_rotary_embedding(mod, config["num_attention_heads"], config["hidden_size"]) diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py index 95bc667a47..e2dff432df 100644 --- a/mlc_llm/transform/fuse_split_rotary_embedding.py +++ b/mlc_llm/transform/fuse_split_rotary_embedding.py @@ -77,7 +77,6 @@ def split_rotary( def fuse_split_rotary_embedding(mod, num_attention_heads, hidden_size, max_sequence_length=2048): head_dim = hidden_size // num_attention_heads - print(f"fuse_split_rotary_embedding {max_sequence_length}") mod["split_rotary"] = get_split_rotary(num_attention_heads, head_dim, max_sequence_length)