46 changes: 31 additions & 15 deletions python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py
@@ -1,5 +1,7 @@
"""A compiler pass that attaches two-stage softmax with temperature."""

from typing import Any, Dict, Optional

import tvm
from tvm import relax, tir
from tvm.ir.module import IRModule
@@ -13,21 +15,28 @@
class AttachSoftmaxWithTemperature: # pylint: disable=too-few-public-methods
"""Rewrites one-shot softmax into two-stage softmax."""

def __init__(self, target: tvm.target.Target) -> None:
def __init__(
self, target: tvm.target.Target, metadata: Optional[Dict[str, Any]] = None
) -> None:
self.target = target
self.metadata = metadata

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""IRModule-level transformation"""
return _Rewriter(mod, self.target).transform()
return _Rewriter(mod, self.target, self.metadata).transform()


@mutator
class _Rewriter(PyExprMutator): # pylint: disable=abstract-method
def __init__(self, mod: IRModule, target: tvm.target.Target) -> None:
def __init__(
self, mod: IRModule, target: tvm.target.Target, metadata: Optional[Dict[str, Any]] = None
) -> None:
super().__init__(mod)
self.mod = mod
self.target = target
self.metadata = metadata
self.chunk_size = 4096
self.active_vocab_size = self.metadata.get("active_vocab_size") if self.metadata else None

def transform(self) -> IRModule:
"""Entry point"""
@@ -47,7 +56,7 @@ def transform(self) -> IRModule:
sinfo_args=relax.TensorStructInfo(new_shape, dtype),
)
f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(
self.target, self.chunk_size
self.target, self.chunk_size, self.active_vocab_size
)
chunked_result_struct_info = relax.TensorStructInfo(
(batch_size, (vocab_size + self.chunk_size - 1) // self.chunk_size),
@@ -82,7 +91,7 @@ def transform(self) -> IRModule:


def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements
target: tvm.target.Target, chunk_size: int
target: tvm.target.Target, chunk_size: int, active_vocab_size: Optional[int]
):
# NOTE: A quick note on the softmax implementation.
# We once tried to multiply every element by log2e which can be computed
@@ -124,8 +133,9 @@ def chunk_lse( # pylint: disable=too-many-locals
for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):
with T.block("pad"):
v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
A_pad[v0, v1, v2] = T.if_then_else(
v1 * T.int64(chunk_size) + v2 < vocab_size,
A_pad[v0, v1, v2] = T.Select(
v1 * T.int64(chunk_size) + v2
< (active_vocab_size if active_vocab_size is not None else vocab_size),
T.if_then_else(
temperature[v0] > T.float32(1e-5),
A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0],
@@ -145,7 +155,8 @@ def chunk_lse( # pylint: disable=too-many-locals
with T.init():
temp_sum[v0, v1] = T.float32(0)
temp_sum[v0, v1] += T.if_then_else(
v1 * T.int64(chunk_size) + v2 < vocab_size,
v1 * T.int64(chunk_size) + v2
< (active_vocab_size if active_vocab_size is not None else vocab_size),
T.Select(
temperature[v0] > T.float32(1e-5),
T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]),
@@ -202,14 +213,19 @@ def softmax_with_chunked_sum(
with T.block("log_pad"):
v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
if v1 * T.int64(chunk_size) + v2 < vocab_size:
softmax[v0, v1 * T.int64(chunk_size) + v2] = T.if_then_else(
temperature[v0] > T.float32(1e-5),
T.exp(
A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0]
- (T.log(temp_sum[v0]) + temp_max[v0])
softmax[v0, v1 * T.int64(chunk_size) + v2] = T.Select(
v1 * T.int64(chunk_size) + v2
< (active_vocab_size if active_vocab_size is not None else vocab_size),
T.if_then_else(
temperature[v0] > T.float32(1e-5),
T.exp(
A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0]
- (T.log(temp_sum[v0]) + temp_max[v0])
),
T.cast(A[v0, v1 * T.int64(chunk_size) + v2] == temp_max[v0], "float32")
/ temp_sum[v0],
),
T.cast(A[v0, v1 * T.int64(chunk_size) + v2] == temp_max[v0], "float32")
/ temp_sum[v0],
T.float32(0),
)

sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_sum": softmax_with_chunked_sum}))
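For reference, a minimal NumPy sketch of the two-stage softmax this pass emits, with the active-vocab bound applied in both stages. It is illustrative only: the near-zero-temperature argmax path in the TIR kernel is omitted, and none of the scheduling is modeled.

import numpy as np

def chunked_softmax(logits, chunk_size, active_vocab_size=None, temperature=1.0):
    """Two-stage softmax: per-chunk (max, exp-sum) first, then one
    normalization pass. Positions >= active_vocab_size get probability 0."""
    vocab = logits.shape[-1]
    limit = active_vocab_size if active_vocab_size is not None else vocab
    x = logits / temperature
    num_chunks = (vocab + chunk_size - 1) // chunk_size
    chunk_max = np.full(num_chunks, -np.inf)
    chunk_sum = np.zeros(num_chunks)
    for c in range(num_chunks):  # stage 1: per-chunk statistics, masked
        lo, hi = c * chunk_size, min((c + 1) * chunk_size, limit)
        if lo < hi:
            chunk_max[c] = x[lo:hi].max()
            chunk_sum[c] = np.exp(x[lo:hi] - chunk_max[c]).sum()
    g_max = chunk_max.max()  # combine chunk stats into a global log-sum-exp
    g_sum = (chunk_sum * np.exp(chunk_max - g_max)).sum()
    out = np.zeros(vocab)  # stage 2: normalize; the inactive tail stays 0
    out[:limit] = np.exp(x[:limit] - (np.log(g_sum) + g_max))
    return out

probs = chunked_softmax(np.random.randn(128256), 4096, active_vocab_size=128000)
assert probs[128000:].sum() == 0.0 and np.isclose(probs.sum(), 1.0)

Masking in both stages keeps padded embedding rows from ever receiving probability mass, which is what threading active_vocab_size through both chunk_lse and softmax_with_chunked_sum achieves.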
2 changes: 1 addition & 1 deletion python/mlc_llm/compiler_pass/pipeline.py
@@ -105,7 +105,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
[
# Phase 0. Add additional information for compilation and remove unused Relax func
DispatchKVCacheCreation(target, flashinfer, metadata),
AttachSoftmaxWithTemperature(target),
AttachSoftmaxWithTemperature(target, metadata),
AttachVariableBounds(variable_bounds),
AttachCUDAGraphSymbolicCaptureHints(cuda_graph_symbolic_capture_hints),
AttachPipelineParallelStages(metadata["pipeline_parallel_stages"]),
12 changes: 11 additions & 1 deletion python/mlc_llm/interface/compile.py
@@ -128,6 +128,8 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
"pipeline_stages": param.attrs.get("pipeline_stages", [0]),
}

logger.info("TOP LEVEL MODEL CONFIG BEFORE OVERRIDES: %s", str(model_config))
_kwargs = getattr(model_config, "kwargs", {})
model_config = args.overrides.apply(model_config)
Comment on lines +131 to 133 (Contributor Author):

Just a note here: I noticed that this override wipes out any kwargs in the original model_config. This PR probably isn't the place to address that, but I wanted to call it out.
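A hypothetical sketch of the hazard described in the comment above, assuming overrides.apply() rebuilds the config from its declared dataclass fields; ModelConfig and apply_overrides are illustrative stand-ins, not the real overrides API:

import dataclasses

@dataclasses.dataclass
class ModelConfig:
    vocab_size: int = 32000
    kwargs: dict = dataclasses.field(default_factory=dict)

def apply_overrides(cfg: ModelConfig, **overrides) -> ModelConfig:
    # Rebuild from declared fields; ad-hoc state stashed in .kwargs is dropped.
    fields = {
        f.name: getattr(cfg, f.name)
        for f in dataclasses.fields(cfg)
        if f.name != "kwargs"
    }
    fields.update(overrides)
    return ModelConfig(**fields)

cfg = ModelConfig()
cfg.kwargs["active_vocab_size"] = 31999
cfg = apply_overrides(cfg, vocab_size=32064)
assert cfg.kwargs == {}  # stashed value gone; hence _kwargs is captured first

Capturing _kwargs before the apply() call, as the diff does, sidesteps the loss.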

with args.target:
op_ext.enable(
@@ -170,6 +172,9 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
"batch_verify": ["batch_size", "seq_len"],
"batch_verify_to_last_hidden_states": ["batch_size", "seq_len"],
}
avs = _kwargs.get("active_vocab_size", None)
if avs is not None and avs <= 0:
avs = None
metadata = {
"model_type": args.model.name,
"quantization": args.quantization.name,
@@ -182,6 +187,7 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
"disaggregation": getattr(model_config, "disaggregation", False),
"kv_state_kind": _infer_kv_state_kind(args.model.name),
"max_batch_size": getattr(model_config, "max_batch_size", 1),
"active_vocab_size": avs,
}
logger.info("Registering metadata: %s", metadata)
metadata["params"] = [_get_param_metadata(name, param) for name, param in named_params]
@@ -221,13 +227,17 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin
debug_dump: Optional[Path] = None,
):
"""Compile a model given its configuration and quantization format to a specific target."""
avs = None
if "active_vocab_size" in config:
avs = config.pop("active_vocab_size")
logger.info("Active vocab size from input config: %s", str(avs))
if "model_config" in config:
model_config = config.pop("model_config")
model_config.update(config)
model_config = model_type.config.from_dict(model_config)
else:
model_config = model_type.config.from_dict(config)
model_config.kwargs = {}
model_config.kwargs = {"active_vocab_size": avs} if avs is not None else {}
args = CompileArgs(
model_config,
quantization,
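The guard above normalizes missing or non-positive values to None, which the compiler pass treats as "no masking, use the full vocab_size". A hypothetical helper restating that contract (the name is illustrative):

from typing import Optional

def _sanitize_active_vocab_size(value: Optional[int]) -> Optional[int]:
    # Non-positive or absent values disable the softmax mask entirely.
    if value is not None and value <= 0:
        return None
    return value

assert _sanitize_active_vocab_size(None) is None
assert _sanitize_active_vocab_size(0) is None
assert _sanitize_active_vocab_size(-1) is None
assert _sanitize_active_vocab_size(128000) == 128000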
25 changes: 25 additions & 0 deletions python/mlc_llm/interface/gen_config.py
@@ -129,6 +129,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
quantization=quantization.name,
model_config=model_config.asdict(),
vocab_size=model_config.vocab_size,
active_vocab_size=getattr(model_config, "active_vocab_size", model_config.vocab_size),
context_window_size=getattr(model_config, "context_window_size", -1),
sliding_window_size=getattr(model_config, "sliding_window_size", -1),
prefill_chunk_size=model_config.prefill_chunk_size,
@@ -245,6 +246,30 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b

# Step 4. Load system default value
apply_system_defaults_for_missing_fields(mlc_chat_config)

# Step 5. Use HF tokenizer to detect active vocab size via len(tokenizer)
if tokenizer_json_file.exists():
try:
from transformers import ( # pylint: disable=import-error,import-outside-toplevel
AutoTokenizer,
)

hf_tokenizer = AutoTokenizer.from_pretrained(str(config.parent), use_fast=True)
active_vocab_size = len(hf_tokenizer)
if mlc_chat_config.active_vocab_size != active_vocab_size:
logger.info(
"Overriding active_vocab_size from %d to %d using HF tokenizer",
mlc_chat_config.active_vocab_size,
active_vocab_size,
)
mlc_chat_config.active_vocab_size = active_vocab_size
except Exception: # pylint: disable=broad-exception-caught
logger.warning(
"Detecting active_vocab_size %s with the exception below. Skipping.",
FAILED,
exc_info=True,
)

# Step 6. Dump the configuration file to output directory
with (output / "mlc-chat-config.json").open("w", encoding="utf-8") as out_file:
json.dump(mlc_chat_config.model_dump(by_alias=True), out_file, indent=2)
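Why len(tokenizer) rather than model_config.vocab_size: embedding tables are commonly padded past the tokenizer's real size (e.g. to a multiple of 64 for kernel efficiency), and tokens added after pretraining are not counted by tokenizer.vocab_size. A usage sketch; the model path is illustrative:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/model", use_fast=True)
print(tok.vocab_size)  # base vocabulary only
print(len(tok))        # base vocabulary + added tokens
# len(tok) is the number of token ids the tokenizer can actually emit, which
# is the right upper bound for the softmax mask's active_vocab_size.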
1 change: 1 addition & 0 deletions python/mlc_llm/protocol/mlc_chat_config.py
@@ -33,6 +33,7 @@ class MLCChatConfig(BaseModel):
# use alias to avoid protected namespace conflict with pydantic
field_model_config: Dict[str, Any] = Field(alias="model_config")
vocab_size: int
active_vocab_size: int
context_window_size: int
sliding_window_size: int
prefill_chunk_size: int