17 changes: 2 additions & 15 deletions build.py
@@ -11,7 +11,7 @@

import mlc_llm
from mlc_llm import utils
-from mlc_llm.relax_model import gpt_neox, llama, minigpt, moss, rwkv
+from mlc_llm.relax_model import gpt_neox, llama, moss, rwkv


def _parse_args():
@@ -266,25 +266,14 @@ def mod_transform_before_build(
    args: argparse.Namespace,
) -> tvm.IRModule:
    """First-stage: Legalize ops and trace"""
-    if args.model.startswith("rwkv-"):
+    if ARGS.model.startswith("rwkv-"):
        model_names = [
            "decode",
            "create_kv_cache",
            "softmax_with_temperature",
            "get_metadata",
            "reset_kv_cache",
        ]
-    elif args.model.startswith("minigpt4-"):
-        model_names = ["embed"]
-    elif args.model_category == "llama":
-        model_names = [
-            "embed",
-            "prefill",
-            "decode",
-            "create_kv_cache",
-            "softmax_with_temperature",
-            "get_metadata",
-        ]
    else:
        model_names = [
            "prefill",
@@ -416,8 +405,6 @@ def main():
        mod, params = moss.get_model(ARGS, config)
    elif ARGS.model_category == "rwkv":
        mod, params = rwkv.get_model(ARGS, config)
-    elif ARGS.model_category == "minigpt":
-        mod, params = minigpt.get_model(ARGS)
    else:
        raise ValueError(f"Model {ARGS.model} not supported")
    mod = mod_transform_before_build(mod, params, ARGS)
36 changes: 1 addition & 35 deletions cpp/cli_main.cc
@@ -206,10 +206,7 @@ class ChatModule {
   */
  explicit ChatModule(const DLDevice& device) {
    this->chat_mod_ = mlc::llm::CreateChatModule(device);
-    this->has_embed_ = this->chat_mod_->GetFunction("has_embed");
-    this->embed_ = this->chat_mod_->GetFunction("embed");
    this->prefill_ = this->chat_mod_->GetFunction("prefill");
-    this->prefill_with_embed_ = this->chat_mod_->GetFunction("prefill_with_embed");
    this->decode_ = this->chat_mod_->GetFunction("decode");
    this->stopped_ = this->chat_mod_->GetFunction("stopped");
    this->get_message_ = this->chat_mod_->GetFunction("get_message");
@@ -221,10 +218,7 @@ class ChatModule {
    this->process_system_prompts_ = this->chat_mod_->GetFunction("process_system_prompts");
    this->lib_path_ = "";
    this->executable_ = tvm::runtime::Module(nullptr);
-    ICHECK(has_embed_ != nullptr);
-    ICHECK(embed_ != nullptr);
    ICHECK(prefill_ != nullptr);
-    ICHECK(prefill_with_embed_ != nullptr);
    ICHECK(decode_ != nullptr);
    ICHECK(stopped_ != nullptr);
    ICHECK(get_message_ != nullptr);
@@ -272,31 +266,12 @@ class ChatModule {
  /*! \return A text describing the runtime statistics. */
  std::string RuntimeStatsText() { return runtime_stats_text_(); }

-  /*!
-   * \brief Check if embed function is defined.
-   */
-  bool HasEmbed() { return has_embed_(); }
-
-  /*!
-   * \brief Run embedding stage to convert input into an embedding.
-   * \param input the user input.
-   */
-  tvm::runtime::Array<tvm::runtime::NDArray> Embed(const std::string& input) {
-    return embed_(input);
-  }
-
  /*!
   * \brief Run prefill stage for a given input and decode the first output token.
   * \param input the user input.
   */
  void Prefill(const std::string& input) { prefill_(input); }

-  /*!
-   * \brief Run prefill stage on an embedding and decode the first output token.
-   * \param embedding the embedding of user input.
-   */
-  void PrefillWithEmbed(tvm::runtime::NDArray embedding) { prefill_with_embed_(embedding); }
-
  /*!
   * \brief Run one decode step to decode the next token.
   */
@@ -315,10 +290,7 @@ class ChatModule {
 protected:
  // TVM Modules and functions with TVM's calling convention
  tvm::runtime::Module chat_mod_;
-  tvm::runtime::PackedFunc has_embed_;
-  tvm::runtime::PackedFunc embed_;
  tvm::runtime::PackedFunc prefill_;
-  tvm::runtime::PackedFunc prefill_with_embed_;
  tvm::runtime::PackedFunc decode_;
  tvm::runtime::PackedFunc stopped_;
  tvm::runtime::PackedFunc get_message_;
@@ -418,13 +390,7 @@ ModelPaths ModelPaths::Find(const std::filesystem::path& artifact_path,
 */
void Converse(ChatModule* chat, const std::string& input, int stream_interval,
              std::ostream& os) {  // NOLINT(*)
-  if (chat->HasEmbed()) {
-    tvm::runtime::Array<tvm::runtime::NDArray> embed_array = chat->Embed(input);
-    ICHECK_EQ(embed_array.size(), 1);
-    chat->PrefillWithEmbed(embed_array[0]);
-  } else {
-    chat->Prefill(input);
-  }
+  chat->Prefill(input);

  std::string cur_msg = "";
  std::vector<std::string> cur_utf8_chars = CountUTF8(cur_msg);
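With the embedding path removed, Converse() always drives the model through the plain prefill/decode loop. Below is a rough sketch (not the upstream code) of that simplified flow, assuming Decode(), Stopped(), and GetMessage() are thin wrappers over the decode_, stopped_, and get_message_ PackedFuncs registered in the constructor above:

#include <ostream>
#include <string>

// Sketch only: one conversation turn against the ChatModule defined above.
void ConverseSketch(ChatModule* chat, const std::string& input, std::ostream& os) {
  chat->Prefill(input);       // prefill the prompt and decode the first token
  while (!chat->Stopped()) {  // keep stepping until a stop token or stop string is hit
    chat->Decode();           // each decode step emits one more token
  }
  os << chat->GetMessage() << "\n";  // print the finished reply (the real CLI streams it)
}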
22 changes: 0 additions & 22 deletions cpp/conv_templates.cc
@@ -257,27 +257,6 @@ Conversation MOSS() {
  return conv;
}

-Conversation MiniGPT() {
-  Conversation conv;
-  conv.name = "minigpt";
-  conv.system =
-      ("Give the following image: <Img>ImageContent</Img>. "
-       "You will be able to see the image once I provide it to you. Please answer my questions.");
-  conv.roles = {"Human", "Assistant"};
-  conv.messages = {};
-  conv.offset = 0;
-  conv.separator_style = SeparatorStyle::kAccumRoleMsg;
-  conv.seps = {"###"};
-  conv.role_msg_sep = ": ";
-  conv.role_empty_sep = ":";
-  // TODO(mlc-team): add eos to mlc-chat-config
-  // and remove eos from stop token setting.
-  conv.stop_tokens = {2};
-  conv.stop_str = "</s>";
-  conv.add_bos = true;
-  return conv;
-}
-
Conversation VanillaLM() {
  Conversation conv;
  conv.name = "LM";
@@ -312,7 +291,6 @@ Conversation Conversation::FromTemplate(const std::string& name) {
      {"oasst", Oasst},
      {"stablelm", StableLM},
      {"moss", MOSS},
-      {"minigpt", MiniGPT},
      {"LM", VanillaLM},
  };
  auto it = factory.find(name);
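For context, the factory map above is also where templates get registered: each entry points at a function that fills in a Conversation, following the same pattern as MOSS() and VanillaLM(). A minimal sketch of a hypothetical template (the name "my_chat" and every value below are invented for illustration and are not part of this change):

// Sketch only: a made-up template shaped like the ones defined in this file.
Conversation MyChat() {
  Conversation conv;
  conv.name = "my_chat";
  conv.system = "You are a helpful assistant.";
  conv.roles = {"User", "Assistant"};
  conv.messages = {};
  conv.offset = 0;
  conv.separator_style = SeparatorStyle::kSepRoleMsg;  // kAccumRoleMsg no longer exists
  conv.seps = {"\n"};
  conv.role_msg_sep = ": ";
  conv.role_empty_sep = ":";
  conv.stop_str = "</s>";
  conv.add_bos = true;
  return conv;
}
// It would then be listed in Conversation::FromTemplate's map as {"my_chat", MyChat}.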
14 changes: 2 additions & 12 deletions cpp/conversation.h
@@ -17,8 +17,6 @@ namespace llm {
enum class SeparatorStyle {
  /*! \brief Add separator between role and message. */
  kSepRoleMsg,
-  /*! \brief Accumulate user input and LM output history in the prompts, separated by separators. */
-  kAccumRoleMsg,
  /*! \brief raw language model style, always only returns last message. */
  kLM,
};
@@ -154,9 +152,6 @@ class Conversation {
   */
  std::vector<std::string> GetPromptArrayLastRound() {
    ICHECK_GE(this->messages.size(), 2);
-    if (this->separator_style == SeparatorStyle::kAccumRoleMsg) {
-      return GetPromptArrayInternal(0);
-    }
    return GetPromptArrayInternal(this->messages.size() - 2);
  }

@@ -209,11 +204,7 @@
      const auto& role = item[0];
      if (item.size() == 2) {
        const std::string message = fproc_message(item[1]);
-        if (this->separator_style == SeparatorStyle::kAccumRoleMsg && i == start_pos) {
-          ret.push_back(role + role_msg_sep + "<Img><ImageHere></Img> " + message + end_sep);
-        } else {
-          ret.push_back(role + role_msg_sep + message + end_sep);
-        }
+        ret.push_back(role + role_msg_sep + message + end_sep);
      } else {
        ICHECK(item.size() == 1);
        ret.push_back(role + role_empty_sep);
@@ -223,8 +214,7 @@
  }
  // dispatcher based on separator style
  std::vector<std::string> GetPromptArrayInternal(size_t start_pos) {
-    if (this->separator_style == SeparatorStyle::kSepRoleMsg ||
-        this->separator_style == SeparatorStyle::kAccumRoleMsg) {
+    if (this->separator_style == SeparatorStyle::kSepRoleMsg) {
      std::string system_prefix;
      if (!this->system.empty()) {
        system_prefix = this->system + this->seps[0];
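After this change, GetPromptArrayInternal only distinguishes kSepRoleMsg from kLM. As a standalone illustration of how the kSepRoleMsg branch flattens a conversation (the concatenation mirrors the lines kept above; all values are invented), a self-contained snippet:

#include <iostream>
#include <string>
#include <vector>

// Sketch only: reproduces, outside the real class, how kSepRoleMsg assembles a prompt:
// the system prompt becomes system + seps[0], each finished round becomes
// role + role_msg_sep + message + sep, and the open slot becomes role + role_empty_sep.
int main() {
  std::string system = "A chat between a curious user and an assistant.";
  std::vector<std::string> seps = {" "};
  std::string role_msg_sep = ": ";
  std::string role_empty_sep = ":";

  std::vector<std::string> prompt;
  prompt.push_back(system + seps[0]);                            // system prefix
  prompt.push_back("Human" + role_msg_sep + "Hello" + seps[0]);  // one completed round
  prompt.push_back("Assistant" + role_empty_sep);                // slot awaiting the reply
  for (const auto& piece : prompt) std::cout << piece;
  std::cout << std::endl;
  // Prints: A chat between a curious user and an assistant. Human: Hello Assistant:
  return 0;
}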