Merged
12 changes: 7 additions & 5 deletions mlc_llm/relax_model/chatglm.py
@@ -284,16 +284,18 @@ def forward(
k_cache, v_cache = past_key_value
f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append")
k_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[k_cache, squeezed_k],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
v_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[v_cache, squeezed_v],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
@@ -304,14 +306,14 @@ def forward(
kv_cache_shape = R.shape([kv_sl, n_groups, head_dim])
f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view")
k = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[k_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)],
)
)
v = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[v_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)],
@@ -703,7 +705,7 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: ChatGLMConfig) -> None:
for _ in range(config.num_layers * 2):
caches.append(
bb.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_create,
args=[zeros, init_shape, relax.PrimValue(0)],
sinfo_args=[relax.ObjectStructInfo()],
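Note: the same substitution repeats in every file in this PR, so a condensed sketch of the pattern in the chatglm.py hunks above may help when reading the remaining diffs. It simply mirrors the calls as written in this diff; the imports, the helper name append_then_view, and the argument names are illustrative placeholders, not part of the change.

# Assumed imports, mirroring what these model files already use.
from tvm import relax
from tvm.script import relax as R
from tvm.relax.testing import nn  # assumed: the emit helper used by these models

def append_then_view(k_cache, squeezed_k, kv_cache_shape, dtype):
    # Appending mutates the cache object, so the bare relax.Call is replaced by
    # call_inplace_packed, with inplace_indices=[0] marking the cache argument.
    f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append")
    k_cache = nn.emit(
        relax.op.call_inplace_packed(
            f_kv_cache_append,
            args=[k_cache, squeezed_k],
            inplace_indices=[0],
            sinfo_args=[relax.ObjectStructInfo()],
        )
    )
    # Viewing the cache as a tensor mutates nothing, so the pure wrapper is enough.
    f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view")
    k = nn.emit(
        relax.call_pure_packed(
            f_kv_cache_view,
            args=[k_cache, kv_cache_shape],
            sinfo_args=[R.Tensor(kv_cache_shape, dtype)],
        )
    )
    return k_cache, k

The inplace_indices entry names which argument position the builtin overwrites (the cache at position 0); the old bare relax.Call carried no such annotation.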
12 changes: 7 additions & 5 deletions mlc_llm/relax_model/gpt_bigcode.py
@@ -221,16 +221,18 @@ def te_slice(x: te.Tensor, start: int, end: int):
k_cache, v_cache = past_key_value
f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append")
k_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[k_cache, squeezed_k],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
v_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[v_cache, squeezed_v],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
@@ -241,14 +243,14 @@ def te_slice(x: te.Tensor, start: int, end: int):
kv_states_shape = R.shape([batch_size, kv_seq_len, head_size])
f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view")
k = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[k_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)],
)
)
v = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[v_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)],
@@ -576,7 +578,7 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: GPTBigCodeConfig) -> No
for _ in range(config.n_layer * 2):
caches.append(
bb.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_create,
args=[zeros, init_shape, relax.PrimValue(0)],
sinfo_args=[relax.ObjectStructInfo()],
12 changes: 7 additions & 5 deletions mlc_llm/relax_model/gpt_neox.py
@@ -114,31 +114,33 @@ def forward(
f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view")
k_cache, v_cache = past_key_value
k_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[k_cache, squeeze(k, axis=0)],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
v_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[v_cache, squeeze(v, axis=0)],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
batch_size, _, num_heads, head_size = k.struct_info.shape
kv_cache_shape = R.shape([kv_seq_len, num_heads, head_size])
kv_states_shape = R.shape([batch_size, kv_seq_len, num_heads, head_size])
k = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[k_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)],
)
)
v = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[v_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)],
@@ -631,7 +633,7 @@ def create_kv_cache_func(
for _ in range(config.num_hidden_layers * 2):
caches.append(
bb.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_create,
args=[zeros, init_shape, relax.PrimValue(0)],
sinfo_args=[relax.ObjectStructInfo()],
10 changes: 6 additions & 4 deletions mlc_llm/relax_model/gptj.py
@@ -153,31 +153,33 @@ def _project(proj):
f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view")
k_cache, v_cache = past_key_value
k_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[k_cache, squeeze(k, axis=0)],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
v_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[v_cache, squeeze(v, axis=0)],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
batch_size, _, num_heads, head_size = k.struct_info.shape
kv_cache_shape = R.shape([kv_seq_len, num_heads, head_size])
kv_states_shape = R.shape([batch_size, kv_seq_len, num_heads, head_size])
k = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[k_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, k.struct_info.dtype)],
)
)
v = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[v_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, v.struct_info.dtype)],
14 changes: 8 additions & 6 deletions mlc_llm/relax_model/llama.py
@@ -463,30 +463,32 @@ def attention_fwd(
k_cache, v_cache = past_key_values
f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append")
k_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[k_cache, squeezed_key],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
v_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_append,
args=[v_cache, squeezed_value],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
past_key_values = (k_cache, v_cache)
f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view")
k_cache = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[k_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)],
)
)
v_cache = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[v_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)],
@@ -1085,7 +1087,7 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
for _ in range(config.num_hidden_layers * 2):
caches.append(
bb.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_create,
args=[zeros, init_shape, relax.PrimValue(0)],
sinfo_args=[relax.ObjectStructInfo()],
@@ -1114,7 +1116,7 @@ def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> N
zeros = bb.emit(relax.op.zeros((), config.dtype))
f_kv_cache_create = relax.extern("vm.builtin.paged_attention_kv_cache_create")
cache = bb.emit_output(
relax.Call(
relax.call_pure_packed(
f_kv_cache_create,
args=[
cache_config,
11 changes: 4 additions & 7 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -91,15 +91,12 @@ def forward(
values_to_cache = nn.emit(take(values, indices_within_window, axis=0))
slot_mapping = nn.emit(take(slot_mapping, indices_within_window, axis=0))

# kv caches are updated inplace, but make it look like a pure operation
# kv caches are updated inplace; the call takes ownership of the arguments
kv = nn.emit(
relax.op.call_pure_packed(
relax.op.call_inplace_packed(
"tvm.contrib.vllm.reshape_and_cache",
keys_to_cache,
values_to_cache,
k_cache,
v_cache,
slot_mapping,
args=[keys_to_cache, values_to_cache, k_cache, v_cache, slot_mapping],
inplace_indices=[2, 3],
sinfo_args=[k_cache.struct_info, v_cache.struct_info],
)
)
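Note: the batched vLLM path above is the one place where the packed call mutates two arguments instead of one. A condensed sketch mirroring that hunk; the imports, the wrapper name update_vllm_cache, and the unpacking step are illustrative placeholders.

from tvm import relax
from tvm.relax.testing import nn  # assumed: the emit helper used by these models

def update_vllm_cache(keys_to_cache, values_to_cache, k_cache, v_cache, slot_mapping):
    # k_cache and v_cache sit at argument positions 2 and 3, so both indices are
    # listed in inplace_indices, and the call yields both updated cache objects,
    # each keeping its original struct info.
    kv = nn.emit(
        relax.op.call_inplace_packed(
            "tvm.contrib.vllm.reshape_and_cache",
            args=[keys_to_cache, values_to_cache, k_cache, v_cache, slot_mapping],
            inplace_indices=[2, 3],
            sinfo_args=[k_cache.struct_info, v_cache.struct_info],
        )
    )
    # Illustrative unpacking of the two updated caches from the returned tuple.
    k_cache = nn.emit(relax.TupleGetItem(kv, 0))
    v_cache = nn.emit(relax.TupleGetItem(kv, 1))
    return k_cache, v_cache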
36 changes: 21 additions & 15 deletions mlc_llm/relax_model/mistral.py
@@ -343,14 +343,14 @@ def te_squeeze(x):

f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view")
key_cached = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[k_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)],
)
)
value_cached = nn.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_view,
args=[v_cache, kv_cache_shape],
sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)],
@@ -400,26 +400,28 @@ def te_squeeze(x):
"vm.builtin.attention_kv_cache_window_override_with_sinks"
)
k_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_override,
args=[
k_cache,
squeezed_key,
relax.PrimValue(self.sliding_window),
relax.PrimValue(attention_sink_size),
],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
v_cache = nn.emit(
relax.Call(
relax.op.call_inplace_packed(
f_kv_cache_override,
args=[
v_cache,
squeezed_value,
relax.PrimValue(self.sliding_window),
relax.PrimValue(attention_sink_size),
],
inplace_indices=[0],
sinfo_args=[relax.ObjectStructInfo()],
)
)
@@ -664,7 +666,9 @@ def forward(self, input_ids: relax.Expr):


class MistralModel(nn.Module):
def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False):
def __init__(
self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False
):
self.num_shards = config.num_shards
self.padding_idx = config.pad_token_id
self.embed_tokens = None
@@ -730,7 +734,9 @@ def forward(


class MistralForCausalLM(nn.Module):
def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False):
def __init__(
self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False
):
self.model = MistralModel(config, vocab_size_var, sep_embed)
self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)

@@ -827,13 +833,13 @@ def create_encoding_func(

bsz = 1
seq_len = tvm.tir.SizeVar("n", "int64") # number of tokens for the input
rolling_cache_len = tvm.tir.SizeVar("c", "int64") # rolling_cache_len captures number of elements in the cache
rolling_cache_len = tvm.tir.SizeVar(
"c", "int64"
) # rolling_cache_len captures number of elements in the cache
kv_seq_len = tvm.tir.SizeVar(
"k", "int64"
) # kv_seq_len captures number of elements in cache + seq_len
cache_offset = tvm.tir.SizeVar(
"o", "int64"
) # sliding window kv cache offset
cache_offset = tvm.tir.SizeVar("o", "int64") # sliding window kv cache offset

hidden_size = config.hidden_size
with bb.function(func_name):
@@ -888,13 +894,13 @@ def create_decoding_func(
func_name = "decode"

bsz = 1
rolling_cache_len = tvm.tir.SizeVar("c", "int64") # rolling_cache_len captures number of elements in the cache
rolling_cache_len = tvm.tir.SizeVar(
"c", "int64"
) # rolling_cache_len captures number of elements in the cache
kv_seq_len = tvm.tir.SizeVar(
"k", "int64"
) # kv_seq_len captures number of elements in cache + seq_len
cache_offset = tvm.tir.SizeVar(
"o", "int64"
) # sliding window kv cache offset
cache_offset = tvm.tir.SizeVar("o", "int64") # sliding window kv cache offset

with bb.function(func_name):
model = MistralForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"))
@@ -952,7 +958,7 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: MistralConfig) -> None:
for _ in range(config.num_hidden_layers * 2):
caches.append(
bb.emit(
relax.Call(
relax.call_pure_packed(
f_kv_cache_create,
args=[zeros, init_shape, relax.PrimValue(0)],
sinfo_args=[relax.ObjectStructInfo()],
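Note: the cache-creation helpers across these files stay on the pure wrapper, presumably because creating a cache allocates a fresh object rather than mutating an existing one; only the append/override paths need the in-place form. A condensed sketch of that creation loop, mirroring the create_kv_cache_func hunks above; the imports and the helper name create_caches are illustrative placeholders.

from tvm import relax

def create_caches(bb: relax.BlockBuilder, num_layers: int, init_shape, dtype: str):
    # One K cache and one V cache per layer, as in the hunks above; zeros is the
    # scalar placeholder passed to the create builtin in those hunks.
    f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create")
    zeros = bb.emit(relax.op.zeros((), dtype))
    caches = []
    for _ in range(num_layers * 2):
        caches.append(
            bb.emit(
                relax.call_pure_packed(
                    f_kv_cache_create,
                    args=[zeros, init_shape, relax.PrimValue(0)],
                    sinfo_args=[relax.ObjectStructInfo()],
                )
            )
        )
    return caches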