[LLM Runtime] Enable Mistral-7b (#552)
* [LLM Runtime] Enable Mistral-7b

Signed-off-by: intellinjun <jun.lin@intel.com>
intellinjun committed Oct 26, 2023
1 parent 3c8c43b commit 7d14956
Showing 6 changed files with 1,286 additions and 13 deletions.
@@ -27,6 +27,7 @@ LLM Runtime supports the following models:
|[OPT-125m](https://huggingface.co/facebook/opt-125m), [OPT-350m](https://huggingface.co/facebook/opt-350m), [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b), [OPT-13B](https://huggingface.co/facebook/opt-13b)|✅|✅|
|[ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b), [ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b)|✅|✅|
|[Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat), [Baichuan2-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat)|✅|✅|
+|[Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)|✅|✅|

### Code Generation
| model name | INT8 | INT4|
@@ -66,6 +66,7 @@ compile_quant(quant_bloom quant_model.cpp bloom bloom)
compile_quant(quant_chatglm quant_model.cpp chatglm chatglm)
compile_quant(quant_chatglm2 quant_model.cpp chatglm2 chatglm2)
compile_quant(quant_baichuan quant_model.cpp baichuan baichuan)
+compile_quant(quant_mistral quant_model.cpp mistral llama)

# all models running
if (NE_PYTHON_API)
@@ -86,6 +87,7 @@ set(mymap_chatglm2 10)
set(mymap_chatglm 11)
set(mymap_baichuan 12)
set(mymap_polyglot 13)
+set(mymap_mistral 14)

function(compile_run TARGET SRC MODEL_NAME MODEL_LIB)
add_executable_w_warning(${TARGET} ${SRC})
@@ -117,3 +119,4 @@ compile_run(run_bloom main_run.cpp bloom bloom)
compile_run(run_chatglm2 main_run.cpp chatglm2 chatglm2)
compile_run(run_chatglm main_run.cpp chatglm chatglm)
compile_run(run_baichuan main_run.cpp baichuan baichuan)
+compile_run(run_mistral main_run.cpp mistral llama)
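
Both helpers share the (TARGET SRC MODEL_NAME MODEL_LIB) signature visible in the compile_run function above, so the two new entries build mistral-named binaries from the common driver sources while linking the existing llama model library rather than a new one. A minimal sketch of how such a driver can stay model-agnostic, assuming the helper injects the model name as a compile definition; MODEL_NAME is hypothetical here, introduced only for illustration (the diff itself only shows MODEL_NAME_ID):

    // Illustrative only: a driver whose model name is baked in at build
    // time, e.g. via -DMODEL_NAME="mistral" for the run_mistral target.
    // MODEL_NAME is an assumed definition, not one shown in this diff.
    #include <cstdio>

    #ifndef MODEL_NAME
    #define MODEL_NAME "mistral"
    #endif

    int main() {
      std::printf("runner compiled for model: %s\n", MODEL_NAME);
      return 0;
    }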
@@ -362,6 +362,10 @@ PYBIND11_MODULE(baichuan_cpp, m)

PYBIND11_MODULE(polyglot_cpp, m)

+#elif MODEL_NAME_ID == 14
+
+PYBIND11_MODULE(mistral_cpp, m)
+
#endif
{
m.doc() = "cpp model python binding";
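
The MODEL_NAME_ID values in this chain line up with the mymap_* numbers set in the CMake file above, so mymap_mistral = 14 means the mistral build compiles exactly one PYBIND11_MODULE declaration and the resulting extension imports from Python as mistral_cpp. A minimal self-contained sketch of the same pattern, assuming the ID arrives as a compile definition (a simplified reconstruction, not the repository's actual file):

    // Simplified reconstruction: CMake passes the mymap_* value as
    // MODEL_NAME_ID, so only one module declaration survives preprocessing.
    #include <pybind11/pybind11.h>

    #ifndef MODEL_NAME_ID
    #define MODEL_NAME_ID 14  // mymap_mistral in the CMake file above
    #endif

    #if MODEL_NAME_ID == 14
    PYBIND11_MODULE(mistral_cpp, m)
    #endif
    {
      m.doc() = "cpp model python binding";  // same doc string as above
    }

After building, Python code loads the binding with "import mistral_cpp", matching the naming of the other *_cpp modules.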
@@ -134,14 +134,8 @@ void Llama::load(model_context& lctx, model_progress_callback progress_callback,

// qkv GEMM
layer.attn[0] = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend);
-if (n_head != n_head_kv) { // In order to distinguish whether it is llama2-70B or not.
-layer.attn[1] = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd / n_head_kv}, backend);
-layer.attn[2] = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd / n_head_kv}, backend);
-} else {
-layer.attn[1] = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend);
-layer.attn[2] = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend);
-}
-
+layer.attn[1] = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd / (n_head / n_head_kv)}, backend);
+layer.attn[2] = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd / (n_head / n_head_kv)}, backend);
layer.attn[3] = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend);

// ffn norm
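
This is the functional core of the change: with grouped-query attention the wk/wv projections produce only n_head_kv heads, so their second dimension is n_embd / (n_head / n_head_kv), i.e. head_dim * n_head_kv. The unified expression collapses to n_embd when n_head_kv == n_head, covering plain multi-head llama checkpoints, and yields the narrower KV width for GQA models such as Llama-2-70B and Mistral-7B, which is why the special-case branch could be deleted. A quick arithmetic check using Mistral-7B's published dimensions (hidden size 4096, 32 attention heads, 8 KV heads; these constants come from the model card, not from this diff):

    #include <cassert>
    #include <cstdio>

    int main() {
      // Mistral-7B configuration (public model card values).
      const int n_embd = 4096, n_head = 32, n_head_kv = 8;

      const int head_dim = n_embd / n_head;                // 128
      const int kv_width = n_embd / (n_head / n_head_kv);  // 4096 / 4 = 1024
      assert(kv_width == n_head_kv * head_dim);            // 8 * 128 = 1024

      // With n_head_kv == n_head the same expression reduces to n_embd,
      // which is why the removed if/else was redundant.
      assert(n_embd / (n_head / n_head) == n_embd);

      std::printf("wk/wv weight shape: {%d, %d}\n", n_embd, kv_width);
      return 0;
    }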
@@ -385,11 +385,11 @@ class model_name_to_arch {
model_name_to_arch() {}
// update this table if has new cpp model
std::unordered_map<std::string, model_archs> name2arch_ = {
{"unknown", MODEL_UNKNOWN}, {"llama", MODEL_LLAMA}, {"gptj", MODEL_GPTJ},
{"mpt", MODEL_MPT}, {"opt", MODEL_OPT}, {"gptneox", MODEL_GPTNEOX},
{"dolly", MODEL_GPTNEOX}, {"polyglot", MODEL_GPTNEOX}, {"starcoder", MODEL_STARCODER},
{"falcon", MODEL_FALCON}, {"bloom", MODEL_BLOOM}, {"chatglm2", MODEL_CHATGLM2},
{"chatglm", MODEL_CHATGLM}, {"baichuan", MODEL_BAICHUAN}};
{"unknown", MODEL_UNKNOWN}, {"llama", MODEL_LLAMA}, {"gptj", MODEL_GPTJ},
{"mpt", MODEL_MPT}, {"opt", MODEL_OPT}, {"gptneox", MODEL_GPTNEOX},
{"dolly", MODEL_GPTNEOX}, {"polyglot", MODEL_GPTNEOX}, {"starcoder", MODEL_STARCODER},
{"falcon", MODEL_FALCON}, {"bloom", MODEL_BLOOM}, {"chatglm2", MODEL_CHATGLM2},
{"chatglm", MODEL_CHATGLM}, {"baichuan", MODEL_BAICHUAN}, {"mistral", MODEL_LLAMA}};
};

#ifdef __cplusplus
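
Registering {"mistral", MODEL_LLAMA} is what routes Mistral-7B through the existing llama code path (including the Llama::load change above), the same way dolly and polyglot reuse MODEL_GPTNEOX. A hedged standalone sketch of the lookup this table enables; arch_for is a hypothetical helper, and the enum is trimmed to the relevant values:

    #include <string>
    #include <unordered_map>

    // Trimmed stand-in for the real model_archs enum.
    enum model_archs { MODEL_UNKNOWN, MODEL_LLAMA, MODEL_GPTNEOX };

    // Hypothetical lookup helper; the real accessor lives inside
    // model_name_to_arch in the runtime.
    static const std::unordered_map<std::string, model_archs> name2arch = {
        {"llama", MODEL_LLAMA},
        {"mistral", MODEL_LLAMA},    // same architecture id as llama
        {"polyglot", MODEL_GPTNEOX},
    };

    model_archs arch_for(const std::string& model_type) {
      const auto it = name2arch.find(model_type);
      return it == name2arch.end() ? MODEL_UNKNOWN : it->second;
    }
    // arch_for("mistral") == MODEL_LLAMA: Mistral-7B shares the llama path.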
