diff --git a/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py b/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py
index 5175f8e72b..05365dca84 100644
--- a/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py
+++ b/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py
@@ -1,5 +1,7 @@
 """A compiler pass that attaches two-stage softmax with temperature."""
 
+from typing import Any, Dict, Optional
+
 import tvm
 from tvm import relax, tir
 from tvm.ir.module import IRModule
@@ -13,21 +15,28 @@
 class AttachSoftmaxWithTemperature:  # pylint: disable=too-few-public-methods
     """Rewrites one-shot softmax into two-stage softmax."""
 
-    def __init__(self, target: tvm.target.Target) -> None:
+    def __init__(
+        self, target: tvm.target.Target, metadata: Optional[Dict[str, Any]] = None
+    ) -> None:
         self.target = target
+        self.metadata = metadata
 
     def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
         """IRModule-level transformation"""
-        return _Rewriter(mod, self.target).transform()
+        return _Rewriter(mod, self.target, self.metadata).transform()
 
 
 @mutator
 class _Rewriter(PyExprMutator):  # pylint: disable=abstract-method
-    def __init__(self, mod: IRModule, target: tvm.target.Target) -> None:
+    def __init__(
+        self, mod: IRModule, target: tvm.target.Target, metadata: Optional[Dict[str, Any]] = None
+    ) -> None:
         super().__init__(mod)
         self.mod = mod
         self.target = target
+        self.metadata = metadata
         self.chunk_size = 4096
+        self.active_vocab_size = self.metadata.get("active_vocab_size") if self.metadata else None
 
     def transform(self) -> IRModule:
         """Entry point"""
@@ -47,7 +56,7 @@ def transform(self) -> IRModule:
                 sinfo_args=relax.TensorStructInfo(new_shape, dtype),
             )
             f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(
-                self.target, self.chunk_size
+                self.target, self.chunk_size, self.active_vocab_size
             )
             chunked_result_struct_info = relax.TensorStructInfo(
                 (batch_size, (vocab_size + self.chunk_size - 1) // self.chunk_size),
@@ -82,7 +91,7 @@ def transform(self) -> IRModule:
 
 
 def _get_lse_and_softmax_func(  # pylint: disable=too-many-locals,too-many-statements
-    target: tvm.target.Target, chunk_size: int
+    target: tvm.target.Target, chunk_size: int, active_vocab_size: Optional[int]
 ):
     # NOTE: A quick note on the softmax implementation.
     # We once tried to multiply every element by log2e which can be computed
@@ -124,8 +133,9 @@ def chunk_lse(  # pylint: disable=too-many-locals
         for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):
             with T.block("pad"):
                 v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
-                A_pad[v0, v1, v2] = T.if_then_else(
-                    v1 * T.int64(chunk_size) + v2 < vocab_size,
+                A_pad[v0, v1, v2] = T.Select(
+                    v1 * T.int64(chunk_size) + v2
+                    < (active_vocab_size if active_vocab_size is not None else vocab_size),
                     T.if_then_else(
                         temperature[v0] > T.float32(1e-5),
                         A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0],
@@ -145,7 +155,8 @@ def chunk_lse(  # pylint: disable=too-many-locals
                 with T.init():
                     temp_sum[v0, v1] = T.float32(0)
                 temp_sum[v0, v1] += T.if_then_else(
-                    v1 * T.int64(chunk_size) + v2 < vocab_size,
+                    v1 * T.int64(chunk_size) + v2
+                    < (active_vocab_size if active_vocab_size is not None else vocab_size),
                     T.Select(
                         temperature[v0] > T.float32(1e-5),
                         T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]),
@@ -202,14 +213,19 @@ def softmax_with_chunked_sum(
             with T.block("log_pad"):
                 v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                 if v1 * T.int64(chunk_size) + v2 < vocab_size:
-                    softmax[v0, v1 * T.int64(chunk_size) + v2] = T.if_then_else(
-                        temperature[v0] > T.float32(1e-5),
-                        T.exp(
-                            A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0]
-                            - (T.log(temp_sum[v0]) + temp_max[v0])
+                    softmax[v0, v1 * T.int64(chunk_size) + v2] = T.Select(
+                        v1 * T.int64(chunk_size) + v2
+                        < (active_vocab_size if active_vocab_size is not None else vocab_size),
+                        T.if_then_else(
+                            temperature[v0] > T.float32(1e-5),
+                            T.exp(
+                                A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0]
+                                - (T.log(temp_sum[v0]) + temp_max[v0])
+                            ),
+                            T.cast(A[v0, v1 * T.int64(chunk_size) + v2] == temp_max[v0], "float32")
+                            / temp_sum[v0],
                         ),
-                        T.cast(A[v0, v1 * T.int64(chunk_size) + v2] == temp_max[v0], "float32")
-                        / temp_sum[v0],
+                        T.float32(0),
                     )
 
     sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_sum": softmax_with_chunked_sum}))
diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py
index 8618af4bd7..c8f4edb627 100644
--- a/python/mlc_llm/compiler_pass/pipeline.py
+++ b/python/mlc_llm/compiler_pass/pipeline.py
@@ -105,7 +105,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
             [
                 # Phase 0. Add additional information for compilation and remove unused Relax func
                 DispatchKVCacheCreation(target, flashinfer, metadata),
-                AttachSoftmaxWithTemperature(target),
+                AttachSoftmaxWithTemperature(target, metadata),
                 AttachVariableBounds(variable_bounds),
                 AttachCUDAGraphSymbolicCaptureHints(cuda_graph_symbolic_capture_hints),
                 AttachPipelineParallelStages(metadata["pipeline_parallel_stages"]),
diff --git a/python/mlc_llm/interface/compile.py b/python/mlc_llm/interface/compile.py
index cf04ee43db..5ea4e5e8b6 100644
--- a/python/mlc_llm/interface/compile.py
+++ b/python/mlc_llm/interface/compile.py
@@ -128,6 +128,8 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
             "pipeline_stages": param.attrs.get("pipeline_stages", [0]),
         }
 
+    logger.info("TOP LEVEL MODEL CONFIG BEFORE OVERRIDES: %s", str(model_config))
+    _kwargs = getattr(model_config, "kwargs", {})
     model_config = args.overrides.apply(model_config)
     with args.target:
         op_ext.enable(
@@ -170,6 +172,9 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
         "batch_verify": ["batch_size", "seq_len"],
         "batch_verify_to_last_hidden_states": ["batch_size", "seq_len"],
     }
+    avs = _kwargs.get("active_vocab_size", None)
+    if avs is not None and avs <= 0:
+        avs = None
     metadata = {
         "model_type": args.model.name,
         "quantization": args.quantization.name,
@@ -182,6 +187,7 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
         "disaggregation": getattr(model_config, "disaggregation", False),
         "kv_state_kind": _infer_kv_state_kind(args.model.name),
         "max_batch_size": getattr(model_config, "max_batch_size", 1),
+        "active_vocab_size": avs,
     }
     logger.info("Registering metadata: %s", metadata)
     metadata["params"] = [_get_param_metadata(name, param) for name, param in named_params]
@@ -221,13 +227,17 @@ def compile(  # pylint: disable=too-many-arguments,redefined-builtin
     debug_dump: Optional[Path] = None,
 ):
     """Compile a model given its configuration and quantization format to a specific target."""
+    avs = None
+    if "active_vocab_size" in config:
+        avs = config.pop("active_vocab_size")
+        logger.info("Active vocab size from input config: %s", str(avs))
     if "model_config" in config:
         model_config = config.pop("model_config")
         model_config.update(config)
         model_config = model_type.config.from_dict(model_config)
     else:
         model_config = model_type.config.from_dict(config)
-    model_config.kwargs = {}
+    model_config.kwargs = {"active_vocab_size": avs} if avs is not None else {}
     args = CompileArgs(
         model_config,
         quantization,
diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py
index a526f2a56d..af24afbd9a 100644
--- a/python/mlc_llm/interface/gen_config.py
+++ b/python/mlc_llm/interface/gen_config.py
@@ -129,6 +129,7 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
         quantization=quantization.name,
         model_config=model_config.asdict(),
         vocab_size=model_config.vocab_size,
+        active_vocab_size=getattr(model_config, "active_vocab_size", model_config.vocab_size),
         context_window_size=getattr(model_config, "context_window_size", -1),
         sliding_window_size=getattr(model_config, "sliding_window_size", -1),
         prefill_chunk_size=model_config.prefill_chunk_size,
@@ -245,6 +246,30 @@ def gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-b
 
     # Step 4. Load system default value
     apply_system_defaults_for_missing_fields(mlc_chat_config)
+
+    # Use the HF tokenizer to detect the active vocab size via len(tokenizer)
+    if tokenizer_json_file.exists():
+        try:
+            from transformers import (  # pylint: disable=import-error,import-outside-toplevel
+                AutoTokenizer,
+            )
+
+            hf_tokenizer = AutoTokenizer.from_pretrained(str(config.parent), use_fast=True)
+            active_vocab_size = len(hf_tokenizer)
+            if mlc_chat_config.active_vocab_size != active_vocab_size:
+                logger.info(
+                    "Overriding active_vocab_size from %d to %d using HF tokenizer",
+                    mlc_chat_config.active_vocab_size,
+                    active_vocab_size,
+                )
+                mlc_chat_config.active_vocab_size = active_vocab_size
+        except Exception:  # pylint: disable=broad-exception-caught
+            logger.warning(
+                "Detecting active_vocab_size %s with the exception below. Skipping.",
+                FAILED,
+                exc_info=True,
+            )
+
     # Step 5. Dump the configuration file to output directory
     with (output / "mlc-chat-config.json").open("w", encoding="utf-8") as out_file:
         json.dump(mlc_chat_config.model_dump(by_alias=True), out_file, indent=2)
diff --git a/python/mlc_llm/protocol/mlc_chat_config.py b/python/mlc_llm/protocol/mlc_chat_config.py
index 28b1df0572..cf79a57a30 100644
--- a/python/mlc_llm/protocol/mlc_chat_config.py
+++ b/python/mlc_llm/protocol/mlc_chat_config.py
@@ -33,6 +33,7 @@ class MLCChatConfig(BaseModel):
     # use alias to avoid protected namespace conflict with pydantic
     field_model_config: Dict[str, Any] = Field(alias="model_config")
     vocab_size: int
+    active_vocab_size: int
     context_window_size: int
     sliding_window_size: int
     prefill_chunk_size: int
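
As a reviewer reference, the sketch below is a minimal single-row NumPy model of the masked two-stage softmax that the rewritten TIR kernels compute. It is not part of the patch: the function name is illustrative, the temperature is simply clamped rather than taking the kernels' greedy branch when temperature is near zero, and batching is omitted.

# Reviewer sketch (assumption-laden, not part of the patch): entries at or beyond
# active_vocab_size receive probability 0, mirroring the added T.Select branches.
import numpy as np


def masked_two_stage_softmax(logits, temperature, chunk_size=4096, active_vocab_size=None):
    vocab_size = logits.shape[0]
    valid = vocab_size if active_vocab_size is None else active_vocab_size
    x = logits[:valid].astype(np.float32) / max(float(temperature), 1e-5)

    # Stage 1 ("chunk_lse"): per-chunk max and sum of exponentials.
    chunk_max, chunk_sum = [], []
    for start in range(0, valid, chunk_size):
        chunk = x[start : start + chunk_size]
        m = float(chunk.max())
        chunk_max.append(m)
        chunk_sum.append(float(np.exp(chunk - m).sum()))

    # Combine per-chunk statistics into a global max and normalizer.
    global_max = max(chunk_max)
    global_sum = sum(s * np.exp(m - global_max) for s, m in zip(chunk_sum, chunk_max))

    # Stage 2 ("softmax_with_chunked_sum"): renormalize; masked tail stays at 0.
    probs = np.zeros(vocab_size, dtype=np.float32)
    probs[:valid] = np.exp(x - (np.log(global_sum) + global_max))
    return probs


# Example: probs = masked_two_stage_softmax(np.random.randn(128256), 0.7, active_vocab_size=128000)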