diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/CMakeLists.txt b/intel_extension_for_transformers/backends/neural_engine/graph/CMakeLists.txt index 2410125b7b5..9f84540d163 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/CMakeLists.txt +++ b/intel_extension_for_transformers/backends/neural_engine/graph/CMakeLists.txt @@ -113,6 +113,7 @@ include(cmake/Common.cmake) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/README.md b/intel_extension_for_transformers/backends/neural_engine/graph/README.md index a356831c877..97f1b25fe1d 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/README.md +++ b/intel_extension_for_transformers/backends/neural_engine/graph/README.md @@ -52,11 +52,19 @@ Running LLAMA model, for details please refer to [LLaMA model documentation](./a OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 ./build/bin/main_llama -m ~/llama.cpp/models/ne-model-q4_j.bin --seed 12 -c 512 -b 1024 -n 256 --keep 48 -t 56 --repeat-penalty 1.0 --color -p "She opened the door and see" ``` -Running GPT-NEOX/ MPT / FALCON model, please use `main_gptneox` / `main_mpt` / `main_falcon`. +Running GPT-NEOX/ MPT / FALCON / GPT-J model, please use `main_gptneox` / `main_mpt` / `main_falcon` / `main_gptj`. ```bash OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 ./build/bin/main_gptneox -m ${output_path}/ne-q8.bin --seed 12 -c 512 -b 1024 -n 256 -t 56 --repeat-penalty 1.0 -p "She opened the door and see" ``` +for GPT-J, you can also try python binds which is experimental currently: + +```bash +cp scripts/gptj_binding.py build +cd build +python gptj_binding.py +``` + ### Supported model -Now we supports [GPT-NeoX](https://github.com/EleutherAI/gpt-neox), [LLaMA](https://github.com/facebookresearch/llama), [MPT](https://huggingface.co/mosaicml/mpt-7b), [FALCON](https://huggingface.co/tiiuae/falcon-7b) +Now we supports [GPT-NeoX](https://github.com/EleutherAI/gpt-neox), [LLaMA](https://github.com/facebookresearch/llama), [MPT](https://huggingface.co/mosaicml/mpt-7b), [FALCON](https://huggingface.co/tiiuae/falcon-7b), [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/application/CMakeLists.txt b/intel_extension_for_transformers/backends/neural_engine/graph/application/CMakeLists.txt index a10d85dff44..f354f84f691 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/application/CMakeLists.txt +++ b/intel_extension_for_transformers/backends/neural_engine/graph/application/CMakeLists.txt @@ -19,9 +19,7 @@ add_library(${TARGET} OBJECT common.cpp ) -if (BUILD_SHARED_LIBS) - set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) -endif() +set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) target_include_directories(${TARGET} PUBLIC .) target_compile_features(${TARGET} PUBLIC cxx_std_11) @@ -33,3 +31,4 @@ add_subdirectory(ChatLLAMA) add_subdirectory(ChatGPTNEOX) add_subdirectory(ChatMPT) add_subdirectory(ChatFALCON) +add_subdirectory(ChatGPTJ) diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/CMakeLists.txt b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/CMakeLists.txt new file mode 100644 index 00000000000..63cbfa8c930 --- /dev/null +++ b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(TARGET main_gptj) +add_executable_w_warning(${TARGET} main_gptj.cpp) +target_link_libraries(${TARGET} PUBLIC ne_layers common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() + +set(TARGET quant_gptj) +add_executable_w_warning(${TARGET} quant_gptj.cpp) +target_link_libraries(${TARGET} PUBLIC ne_layers common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() + +set(TARGET GptjPyBind) +add_library_w_warning(${TARGET} SHARED pybind_gptj.cpp) +target_link_libraries(${TARGET} PUBLIC ne_layers common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) + +set(TARGET pybind_gptj) +add_executable_w_warning(${TARGET} pybind_gptj.cpp) +target_link_libraries(${TARGET} PUBLIC ne_layers common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/main_gptj.cpp b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/main_gptj.cpp new file mode 100644 index 00000000000..ee64de4f31a --- /dev/null +++ b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/main_gptj.cpp @@ -0,0 +1,691 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "core/ne_layers.h" +#include "data_types.h" +#include "ne.h" + +#if defined(_MSC_VER) +#pragma warning(disable : 4244 4267) // possible loss of data +#endif + +// default hparams (GPT-J 6B) +struct gptj_hparams { + int32_t n_vocab = 50400; + int32_t n_ctx = 2048; + int32_t n_embd = 4096; + int32_t n_head = 16; + int32_t n_layer = 28; + int32_t n_rot = 64; + int32_t ftype = 1; +}; + +struct gptj_layer { + // normalization + struct ne_tensor* ln_1_g; + struct ne_tensor* ln_1_b; + + // attention + struct ne_tensor* c_attn_q_proj_w; + struct ne_tensor* c_attn_k_proj_w; + struct ne_tensor* c_attn_v_proj_w; + + struct ne_tensor* c_attn_proj_w; + + // ff + struct ne_tensor* c_mlp_fc_w; + struct ne_tensor* c_mlp_fc_b; + + struct ne_tensor* c_mlp_proj_w; + struct ne_tensor* c_mlp_proj_b; +}; + +struct gptj_model { + gptj_hparams hparams; + + // normalization + struct ne_tensor* ln_f_g; + struct ne_tensor* ln_f_b; + + struct ne_tensor* wte; // position embedding + + struct ne_tensor* lmh_g; // language model head + struct ne_tensor* lmh_b; // language model bias + + std::vector layers; + + // key + value memory + struct ne_tensor* memory_k; + struct ne_tensor* memory_v; + + // + struct ne_context* ctx; + std::map tensors; +}; + +// load the model's weights from a file +bool gptj_model_load(const std::string& fname, gptj_model& model, gpt_vocab& vocab) { + printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); + + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); + return false; + } + + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + if (magic != NE_FILE_MAGIC) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); + return false; + } + // load hparams + auto& hparams = model.hparams; + + fin.read((char*)&hparams.n_vocab, sizeof(hparams.n_vocab)); + fin.read((char*)&hparams.n_ctx, sizeof(hparams.n_ctx)); + fin.read((char*)&hparams.n_embd, sizeof(hparams.n_embd)); + fin.read((char*)&hparams.n_head, sizeof(hparams.n_head)); + fin.read((char*)&hparams.n_layer, sizeof(hparams.n_layer)); + fin.read((char*)&hparams.n_rot, sizeof(hparams.n_rot)); + fin.read((char*)&hparams.ftype, sizeof(hparams.ftype)); + + const int32_t qntvr = hparams.ftype / NE_QNT_VERSION_FACTOR; + + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); + printf("%s: n_rot = %d\n", __func__, hparams.n_rot); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + printf("%s: qntvr = %d\n", __func__, qntvr); + + hparams.ftype %= NE_QNT_VERSION_FACTOR; + + // load vocab + int32_t n_vocab = 0; + fin.read((char*)&n_vocab, sizeof(n_vocab)); + + if (n_vocab != model.hparams.n_vocab) { + fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", __func__, fname.c_str(), n_vocab, + model.hparams.n_vocab); + return false; + } + + std::string word; + std::vector buf(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + fin.read((char*)&len, sizeof(len)); + + buf.resize(len); + fin.read((char*)buf.data(), len); + word.assign(buf.data(), len); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + ne_type wtype = ne_ftype_to_ne_type((ne_ftype)(model.hparams.ftype)); + if (wtype == NE_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", __func__, fname.c_str(), model.hparams.ftype); + return false; + } + + auto& ctx = model.ctx; + + size_t ctx_size = 0; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + + ctx_size += n_embd * ne_type_sizef(NE_TYPE_F32); // ln_f_g + ctx_size += n_embd * ne_type_sizef(NE_TYPE_F32); // ln_f_b + + ctx_size += n_embd * n_vocab * ne_type_sizef(wtype); // wte + + ctx_size += n_embd * n_vocab * ne_type_sizef(wtype); // lmh_g + ctx_size += n_vocab * ne_type_sizef(NE_TYPE_F32); // lmh_b + + ctx_size += n_layer * (n_embd * ne_type_sizef(NE_TYPE_F32)); // ln_1_g + ctx_size += n_layer * (n_embd * ne_type_sizef(NE_TYPE_F32)); // ln_1_b + + ctx_size += n_layer * (n_embd * n_embd * ne_type_sizef(wtype)); // c_attn_q_proj_w + ctx_size += n_layer * (n_embd * n_embd * ne_type_sizef(wtype)); // c_attn_k_proj_w + ctx_size += n_layer * (n_embd * n_embd * ne_type_sizef(wtype)); // c_attn_v_proj_w + + ctx_size += n_layer * (n_embd * n_embd * ne_type_sizef(wtype)); // c_attn_proj_w + + ctx_size += n_layer * (4 * n_embd * n_embd * ne_type_sizef(wtype)); // c_mlp_fc_w + ctx_size += n_layer * (4 * n_embd * ne_type_sizef(NE_TYPE_F32)); // c_mlp_fc_b + + ctx_size += n_layer * (4 * n_embd * n_embd * ne_type_sizef(wtype)); // c_mlp_proj_w + ctx_size += n_layer * (n_embd * ne_type_sizef(NE_TYPE_F32)); // c_mlp_proj_b + + ctx_size += n_ctx * n_layer * n_embd * ne_type_sizef(NE_TYPE_F16); // memory_k + ctx_size += n_ctx * n_layer * n_embd * ne_type_sizef(NE_TYPE_F16); // memory_v + + ctx_size += (5 + 10 * n_layer) * 512; // object overhead + + printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0)); + + // create the ggml context + struct ne_init_params params = { + /*.mem_size =*/ctx_size, + /*.mem_buffer =*/NULL, + /*.no_alloc =*/false, + }; + + model.ctx = ne_init(params); + if (!model.ctx) { + fprintf(stderr, "%s: ne_init() failed\n", __func__); + return false; + } + + // prepare memory for the weights + + model.layers.resize(n_layer); + + model.wte = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + + model.ln_f_g = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + model.ln_f_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + + model.lmh_g = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + model.lmh_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_vocab); + + // map by name + model.tensors["transformer.wte.weight"] = model.wte; + + model.tensors["transformer.ln_f.weight"] = model.ln_f_g; + model.tensors["transformer.ln_f.bias"] = model.ln_f_b; + + model.tensors["lm_head.weight"] = model.lmh_g; + model.tensors["lm_head.bias"] = model.lmh_b; + + for (int i = 0; i < n_layer; ++i) { + auto& layer = model.layers[i]; + + layer.ln_1_g = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + layer.ln_1_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + + layer.c_attn_q_proj_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_k_proj_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_v_proj_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_embd); + + layer.c_attn_proj_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_embd); + + layer.c_mlp_fc_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, 4 * n_embd); + layer.c_mlp_fc_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, 4 * n_embd); + + layer.c_mlp_proj_w = d_ne_new_tensor_2d(ctx, wtype, 4 * n_embd, n_embd); + layer.c_mlp_proj_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + + // map by name + model.tensors["transformer.h." + std::to_string(i) + ".ln_1.weight"] = layer.ln_1_g; + model.tensors["transformer.h." + std::to_string(i) + ".ln_1.bias"] = layer.ln_1_b; + + model.tensors["transformer.h." + std::to_string(i) + ".attn.q_proj.weight"] = layer.c_attn_q_proj_w; + model.tensors["transformer.h." + std::to_string(i) + ".attn.k_proj.weight"] = layer.c_attn_k_proj_w; + model.tensors["transformer.h." + std::to_string(i) + ".attn.v_proj.weight"] = layer.c_attn_v_proj_w; + + model.tensors["transformer.h." + std::to_string(i) + ".attn.out_proj.weight"] = layer.c_attn_proj_w; + + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.weight"] = layer.c_mlp_fc_w; + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.bias"] = layer.c_mlp_fc_b; + + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.weight"] = layer.c_mlp_proj_w; + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.bias"] = layer.c_mlp_proj_b; + } + + // key + value memory + const int n_mem = n_layer * n_ctx; + const int n_elements = n_embd * n_mem; + + model.memory_k = d_ne_new_tensor_1d(ctx, NE_TYPE_F16, n_elements); + model.memory_v = d_ne_new_tensor_1d(ctx, NE_TYPE_F16, n_elements); + + const size_t memory_size = ne_nbytes(model.memory_k) + ne_nbytes(model.memory_v); + + printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size / 1024.0 / 1024.0, n_mem); + + // load weights + int n_tensors = 0; + size_t total_size = 0; + + printf("%s: ", __func__); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ttype), sizeof(ttype)); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = {1, 1}; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + fin.read(&name[0], length); + + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (ne_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", __func__, + name.data(), (int)tensor->ne[0], (int)tensor->ne[1], ne[0], ne[1]); + return false; + } + + // for debugging + if (0) { + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], + ne_type_name(ne_type(ttype)), ne_nbytes(tensor) / 1024.0 / 1024.0, ne_nbytes(tensor)); + } + + const size_t bpe = ne_type_size(ne_type(ttype)); + + if ((nelements * bpe) / ne_blck_size(tensor->type) != ne_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), + ne_nbytes(tensor), nelements * bpe); + return false; + } + + fin.read(reinterpret_cast(tensor->data), ne_nbytes(tensor)); + + // printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ttype == 0 ? "float" : "f16", + // ne_nbytes(tensor)/1024.0/1024.0); + total_size += ne_nbytes(tensor); + if (++n_tensors % 8 == 0) { + printf("."); + fflush(stdout); + } + } + + fin.close(); + + return true; +} + +// evaluate the transformer +// +// - model: the model +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +// The GPT-J model requires about 16MB of memory per input token. +// +bool gptj_eval(const gptj_model& model, const int n_threads, const int n_past, + const std::vector& embd_inp, std::vector& embd_w, size_t& mem_per_token) { + const int N = embd_inp.size(); + + const auto& hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + const int n_vocab = hparams.n_vocab; + const int n_rot = hparams.n_rot; + + static size_t buf_size = 256u * 1024 * 1024; + static void* buf = malloc(buf_size); + + if (mem_per_token > 0 && mem_per_token * N > buf_size) { + const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead + // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); + + // reallocate + buf_size = buf_size_new; + buf = realloc(buf, buf_size); + if (buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + return false; + } + } + + struct ne_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/buf, + /*.no_alloc =*/false, + }; + + struct ne_context* ctx0 = ne_init(params); + struct ne_cgraph gf = {}; + gf.n_threads = n_threads; + + struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N * ne_element_size(embd)); + + // wte + struct ne_tensor* inpL = ne_get_rows(ctx0, model.wte, embd); + + for (int il = 0; il < n_layer; ++il) { + struct ne_tensor* cur; + + // norm + { + cur = ne_norm(ctx0, inpL); + + // cur = ln_1_g*cur + ln_1_b + cur = ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].ln_1_g, cur), cur), + ne_repeat(ctx0, model.layers[il].ln_1_b, cur)); + } + + struct ne_tensor* inpSA = cur; + + // self-attention + { + struct ne_tensor* Qcur = ne_rope_inplace( + ctx0, + ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd / n_head, n_head, N), + n_past, n_rot, 0); + struct ne_tensor* Kcur = ne_rope_inplace( + ctx0, + ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd / n_head, n_head, N), + n_past, n_rot, 0); + + // store key and value to memory + { + struct ne_tensor* Vcur = ne_transpose(ctx0, ne_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur)); + + struct ne_tensor* k = ne_view_1d(ctx0, model.memory_k, N * n_embd, + (ne_element_size(model.memory_k) * n_embd) * (il * n_ctx + n_past)); + struct ne_tensor* v = ne_view_2d( + ctx0, model.memory_v, N, n_embd, (n_ctx)*ne_element_size(model.memory_v), + (il * n_ctx) * ne_element_size(model.memory_v) * n_embd + n_past * ne_element_size(model.memory_v)); + + ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur, k)); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur, v)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + struct ne_tensor* K = ne_permute(ctx0, + ne_reshape_3d(ctx0, + ne_view_1d(ctx0, model.memory_k, (n_past + N) * n_embd, + il * n_ctx * ne_element_size(model.memory_k) * n_embd), + n_embd / n_head, n_head, n_past + N), + 0, 2, 1, 3); + + // K * Q + struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head))); + + // KQ_masked = mask_past(KQ_scaled) + struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ne_tensor* V = + ne_view_3d(ctx0, model.memory_v, n_past + N, n_embd / n_head, n_head, n_ctx * ne_element_size(model.memory_v), + n_ctx * ne_element_size(model.memory_v) * n_embd / n_head, + il * n_ctx * ne_element_size(model.memory_v) * n_embd); + + // KQV = transpose(V) * KQ_soft_max + struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N)); + + // projection (no bias) + cur = ne_mul_mat(ctx0, model.layers[il].c_attn_proj_w, cur); + } + + struct ne_tensor* inpFF = cur; + + // feed-forward network + // this is independent of the self-attention result, so it could be done in parallel to the self-attention + { + // note here we pass inpSA instead of cur + cur = ne_mul_mat(ctx0, model.layers[il].c_mlp_fc_w, inpSA); + + cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur), cur); + + // GELU activation + cur = ne_gelu(ctx0, cur); + + // projection + // cur = proj_w*cur + proj_b + cur = ne_mul_mat(ctx0, model.layers[il].c_mlp_proj_w, cur); + + cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur), cur); + } + + // self-attention + FF + cur = ne_add(ctx0, cur, inpFF); + + // input for next layer + inpL = ne_add(ctx0, cur, inpL); + } + + // norm + { + inpL = ne_norm(ctx0, inpL); + + // inpL = ln_f_g*inpL + ln_f_b + inpL = ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, model.ln_f_g, inpL), inpL), ne_repeat(ctx0, model.ln_f_b, inpL)); + } + + // lm_head + { + inpL = ne_mul_mat(ctx0, model.lmh_g, inpL); + + inpL = ne_add(ctx0, ne_repeat(ctx0, model.lmh_b, inpL), inpL); + } + + // logits -> probs + // inpL = ne_soft_max_inplace(ctx0, inpL); + + // run the computation + ne_build_forward_expand(&gf, inpL); + ne_graph_compute(ctx0, &gf); + + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float*)ne_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ne_used_mem(ctx0) / N; + } + // printf("used_mem = %zu\n", ne_used_mem(ctx0)); + + ne_free(ctx0); + + return true; +} + +int main(int argc, char** argv) { + ne_time_init(); + + const int64_t t_main_start_us = ne_time_us(); + + common_params params; + params.model = "models/gpt-j-6B/ggml-model.bin"; + + if (common_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + + printf("%s: seed = %d\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.prompt.empty()) { + params.prompt = gpt_random_prompt(rng); + } + + int64_t t_load_us = 0; + + gpt_vocab vocab; + gptj_model model; + + // load the model + { + const int64_t t_start_us = ne_time_us(); + + if (!gptj_model_load(params.model, model, vocab)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ne_time_us() - t_start_us; + + test_gpt_tokenizer(vocab, params.token_test); + } + + int n_past = 0; + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); + + params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int)embd_inp.size()); + + printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + printf("\n"); + + std::vector embd; + + // determine the required inference memory per token: + size_t mem_per_token = 0; + gptj_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, mem_per_token); + + for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { + // predict + if (embd.size() > 0) { + const int64_t t_start_us = ne_time_us(); + + if (!gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) { + printf("Failed to predict\n"); + return 1; + } + + t_predict_us += ne_time_us() - t_start_us; + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + // sample next token + const int top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + + const int n_vocab = model.hparams.n_vocab; + + gpt_vocab::id id = 0; + + { + const int64_t t_start_sample_us = ne_time_us(); + + id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); + + t_sample_us += ne_time_us() - t_start_sample_us; + } + + // add it to the context + embd.push_back(id); + } else { + // if here, it means we are still processing the input prompt + for (int k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + if (embd.size() > params.n_batch) { + break; + } + } + i += embd.size() - 1; + } + + // display text + for (auto id : embd) { + printf("%s", vocab.id_to_token[id].c_str()); + } + fflush(stdout); + + // end of text token + if (embd.back() == 50256) { + break; + } + } + + // report timing + { + const int64_t t_main_end_us = ne_time_us(); + + printf("\n\n"); + printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); + printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f); + printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us / 1000.0f); + printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f, + t_predict_us / 1000.0f / n_past); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); + } + + ne_free(model.ctx); + + return 0; +} diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/pybind_gptj.cpp b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/pybind_gptj.cpp new file mode 100644 index 00000000000..dcda0239261 --- /dev/null +++ b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/pybind_gptj.cpp @@ -0,0 +1,741 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "core/ne_layers.h" +#include "data_types.h" +#include "ne.h" + +#if defined(_MSC_VER) +#pragma warning(disable : 4244 4267) // possible loss of data +#endif + +#define N_thread 56 + +// default hparams (GPT-J 6B) +struct gptj_hparams { + int32_t n_vocab = 50400; + int32_t n_ctx = 2048; + int32_t n_embd = 4096; + int32_t n_head = 16; + int32_t n_layer = 28; + int32_t n_rot = 64; + int32_t ftype = 1; +}; + +struct gptj_layer { + // normalization + struct ne_tensor* ln_1_g; + struct ne_tensor* ln_1_b; + + // attention + struct ne_tensor* c_attn_q_proj_w; + struct ne_tensor* c_attn_k_proj_w; + struct ne_tensor* c_attn_v_proj_w; + + struct ne_tensor* c_attn_proj_w; + + // ff + struct ne_tensor* c_mlp_fc_w; + struct ne_tensor* c_mlp_fc_b; + + struct ne_tensor* c_mlp_proj_w; + struct ne_tensor* c_mlp_proj_b; +}; + +struct gptj_model { + gptj_hparams hparams; + + // normalization + struct ne_tensor* ln_f_g; + struct ne_tensor* ln_f_b; + + struct ne_tensor* wte; // position embedding + + struct ne_tensor* lmh_g; // language model head + struct ne_tensor* lmh_b; // language model bias + + std::vector layers; + + // key + value memory + struct ne_tensor* memory_k; + struct ne_tensor* memory_v; + + // + struct ne_context* ctx; + std::map tensors; + + ~gptj_model() { ne_free(ctx); } +}; + +struct gptj_all { + struct gptj_model* model; + struct gpt_vocab* vocab; + ~gptj_all() { + delete model; + delete vocab; + } +}; + +// load the model's weights from a file +bool gptj_model_load(const std::string& fname, gptj_model& model, gpt_vocab& vocab) { + printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); + + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); + return false; + } + + // verify magic + uint32_t magic; + fin.read((char*)&magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); + return false; + } + + // load hparams + auto& hparams = model.hparams; + + fin.read((char*)&hparams.n_vocab, sizeof(hparams.n_vocab)); + fin.read((char*)&hparams.n_ctx, sizeof(hparams.n_ctx)); + fin.read((char*)&hparams.n_embd, sizeof(hparams.n_embd)); + fin.read((char*)&hparams.n_head, sizeof(hparams.n_head)); + fin.read((char*)&hparams.n_layer, sizeof(hparams.n_layer)); + fin.read((char*)&hparams.n_rot, sizeof(hparams.n_rot)); + fin.read((char*)&hparams.ftype, sizeof(hparams.ftype)); + + const int32_t qntvr = hparams.ftype / NE_QNT_VERSION_FACTOR; + + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); + printf("%s: n_rot = %d\n", __func__, hparams.n_rot); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + printf("%s: qntvr = %d\n", __func__, qntvr); + + hparams.ftype %= NE_QNT_VERSION_FACTOR; + + // load vocab + int32_t file_n_vocab = 0; + fin.read((char*)&file_n_vocab, sizeof(file_n_vocab)); + + if (file_n_vocab != model.hparams.n_vocab) { + fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", __func__, fname.c_str(), file_n_vocab, + model.hparams.n_vocab); + return false; + } + + std::string word; + std::vector buf(128); + + for (int i = 0; i < file_n_vocab; i++) { + uint32_t len; + fin.read((char*)&len, sizeof(len)); + + buf.resize(len); + fin.read((char*)buf.data(), len); + word.assign(buf.data(), len); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + ne_type wtype = ne_ftype_to_ne_type((ne_ftype)(model.hparams.ftype)); + if (wtype == NE_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", __func__, fname.c_str(), model.hparams.ftype); + return false; + } + + auto& ctx = model.ctx; + + size_t ctx_size = 0; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + + ctx_size += n_embd * ne_type_sizef(NE_TYPE_F32); // ln_f_g + ctx_size += n_embd * ne_type_sizef(NE_TYPE_F32); // ln_f_b + + ctx_size += n_embd * n_vocab * ne_type_sizef(wtype); // wte + + ctx_size += n_embd * n_vocab * ne_type_sizef(wtype); // lmh_g + ctx_size += n_vocab * ne_type_sizef(NE_TYPE_F32); // lmh_b + + ctx_size += n_layer * (n_embd * ne_type_sizef(NE_TYPE_F32)); // ln_1_g + ctx_size += n_layer * (n_embd * ne_type_sizef(NE_TYPE_F32)); // ln_1_b + + ctx_size += n_layer * (n_embd * n_embd * ne_type_sizef(wtype)); // c_attn_q_proj_w + ctx_size += n_layer * (n_embd * n_embd * ne_type_sizef(wtype)); // c_attn_k_proj_w + ctx_size += n_layer * (n_embd * n_embd * ne_type_sizef(wtype)); // c_attn_v_proj_w + + ctx_size += n_layer * (n_embd * n_embd * ne_type_sizef(wtype)); // c_attn_proj_w + + ctx_size += n_layer * (4 * n_embd * n_embd * ne_type_sizef(wtype)); // c_mlp_fc_w + ctx_size += n_layer * (4 * n_embd * ne_type_sizef(NE_TYPE_F32)); // c_mlp_fc_b + + ctx_size += n_layer * (4 * n_embd * n_embd * ne_type_sizef(wtype)); // c_mlp_proj_w + ctx_size += n_layer * (n_embd * ne_type_sizef(NE_TYPE_F32)); // c_mlp_proj_b + + ctx_size += n_ctx * n_layer * n_embd * ne_type_sizef(NE_TYPE_F16); // memory_k + ctx_size += n_ctx * n_layer * n_embd * ne_type_sizef(NE_TYPE_F16); // memory_v + + ctx_size += (5 + 10 * n_layer) * 512; // object overhead + + struct ne_init_params params = { + /*.mem_size =*/ctx_size, + /*.mem_buffer =*/NULL, + /*.no_alloc =*/false, + }; + + model.ctx = ne_init(params); + if (!model.ctx) { + fprintf(stderr, "%s: ne_init() failed\n", __func__); + return false; + } + + // prepare memory for the weights + + model.layers.resize(n_layer); + + model.wte = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + + model.ln_f_g = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + model.ln_f_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + + model.lmh_g = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + model.lmh_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_vocab); + + // map by name + model.tensors["transformer.wte.weight"] = model.wte; + + model.tensors["transformer.ln_f.weight"] = model.ln_f_g; + model.tensors["transformer.ln_f.bias"] = model.ln_f_b; + + model.tensors["lm_head.weight"] = model.lmh_g; + model.tensors["lm_head.bias"] = model.lmh_b; + + for (int i = 0; i < n_layer; ++i) { + auto& layer = model.layers[i]; + + layer.ln_1_g = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + layer.ln_1_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + + layer.c_attn_q_proj_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_k_proj_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.c_attn_v_proj_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_embd); + + layer.c_attn_proj_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, n_embd); + + layer.c_mlp_fc_w = d_ne_new_tensor_2d(ctx, wtype, n_embd, 4 * n_embd); + layer.c_mlp_fc_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, 4 * n_embd); + + layer.c_mlp_proj_w = d_ne_new_tensor_2d(ctx, wtype, 4 * n_embd, n_embd); + layer.c_mlp_proj_b = d_ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd); + + // map by name + model.tensors["transformer.h." + std::to_string(i) + ".ln_1.weight"] = layer.ln_1_g; + model.tensors["transformer.h." + std::to_string(i) + ".ln_1.bias"] = layer.ln_1_b; + + model.tensors["transformer.h." + std::to_string(i) + ".attn.q_proj.weight"] = layer.c_attn_q_proj_w; + model.tensors["transformer.h." + std::to_string(i) + ".attn.k_proj.weight"] = layer.c_attn_k_proj_w; + model.tensors["transformer.h." + std::to_string(i) + ".attn.v_proj.weight"] = layer.c_attn_v_proj_w; + + model.tensors["transformer.h." + std::to_string(i) + ".attn.out_proj.weight"] = layer.c_attn_proj_w; + + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.weight"] = layer.c_mlp_fc_w; + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_in.bias"] = layer.c_mlp_fc_b; + + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.weight"] = layer.c_mlp_proj_w; + model.tensors["transformer.h." + std::to_string(i) + ".mlp.fc_out.bias"] = layer.c_mlp_proj_b; + } + + // key + value memory + const int n_mem = n_layer * n_ctx; + const int n_elements = n_embd * n_mem; + model.memory_k = d_ne_new_tensor_1d(ctx, NE_TYPE_F16, n_elements); + model.memory_v = d_ne_new_tensor_1d(ctx, NE_TYPE_F16, n_elements); + + const size_t memory_size = ne_nbytes(model.memory_k) + ne_nbytes(model.memory_v); + + printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size / 1024.0 / 1024.0, n_mem); + + // load weights + int n_tensors = 0; + size_t total_size = 0; + + printf("%s: ", __func__); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ttype), sizeof(ttype)); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = {1, 1}; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + fin.read(&name[0], length); + + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (ne_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", __func__, + name.data(), (int)tensor->ne[0], (int)tensor->ne[1], ne[0], ne[1]); + return false; + } + + // for debugging + if (0) { + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], + ne_type_name(ne_type(ttype)), ne_nbytes(tensor) / 1024.0 / 1024.0, ne_nbytes(tensor)); + } + + const size_t bpe = ne_type_size(ne_type(ttype)); + + if ((nelements * bpe) / ne_blck_size(tensor->type) != ne_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), + ne_nbytes(tensor), nelements * bpe); + return false; + } + + fin.read(reinterpret_cast(tensor->data), ne_nbytes(tensor)); + + // printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ttype == 0 ? "float" : "f16", + // ne_nbytes(tensor)/1024.0/1024.0); + total_size += ne_nbytes(tensor); + if (++n_tensors % 8 == 0) { + printf("."); + fflush(stdout); + } + } + printf(" done\n"); + printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size / 1024.0 / 1024.0, n_tensors); + + fin.close(); + + return true; +} + +// evaluate the transformer +// +// - model: the model +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +// The GPT-J model requires about 16MB of memory per input token. +// +bool gptj_eval(const gptj_model& model, const int n_threads, const int n_past, + const std::vector& embd_inp, std::vector& embd_w, size_t& mem_per_token) { + const int N = embd_inp.size(); + + const auto& hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + const int n_vocab = hparams.n_vocab; + const int n_rot = hparams.n_rot; + + static size_t buf_size = 256u * 1024 * 1024; + static void* buf = malloc(buf_size); + + if (mem_per_token > 0 && mem_per_token * N > buf_size) { + const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead + + // reallocate + buf_size = buf_size_new; + buf = realloc(buf, buf_size); + if (buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + return false; + } + } + + struct ne_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/buf, + /*.no_alloc =*/false, + }; + + struct ne_context* ctx0 = ne_init(params); + struct ne_cgraph gf = {}; + gf.n_threads = n_threads; + + struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N * ne_element_size(embd)); + + // wte + struct ne_tensor* inpL = ne_get_rows(ctx0, model.wte, embd); + + for (int il = 0; il < n_layer; ++il) { + struct ne_tensor* cur; + + // norm + cur = ne_norm(ctx0, inpL); + + // cur = ln_1_g*cur + ln_1_b + cur = ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].ln_1_g, cur), cur), + ne_repeat(ctx0, model.layers[il].ln_1_b, cur)); + + struct ne_tensor* inpSA = cur; + + // self-attention + struct ne_tensor* Qcur = ne_rope_inplace( + ctx0, ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd / n_head, n_head, N), + n_past, n_rot, 0); + struct ne_tensor* Kcur = ne_rope_inplace( + ctx0, ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd / n_head, n_head, N), + n_past, n_rot, 0); + + // store key and value to memory + struct ne_tensor* Vcur = ne_transpose(ctx0, ne_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur)); + + struct ne_tensor* k = ne_view_1d(ctx0, model.memory_k, N * n_embd, + (ne_element_size(model.memory_k) * n_embd) * (il * n_ctx + n_past)); + struct ne_tensor* v = + ne_view_2d(ctx0, model.memory_v, N, n_embd, (n_ctx)*ne_element_size(model.memory_v), + (il * n_ctx) * ne_element_size(model.memory_v) * n_embd + n_past * ne_element_size(model.memory_v)); + + ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur, k)); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur, v)); + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + struct ne_tensor* K = ne_permute(ctx0, + ne_reshape_3d(ctx0, + ne_view_1d(ctx0, model.memory_k, (n_past + N) * n_embd, + il * n_ctx * ne_element_size(model.memory_k) * n_embd), + n_embd / n_head, n_head, n_past + N), + 0, 2, 1, 3); + + // K * Q + struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head))); + + // KQ_masked = mask_past(KQ_scaled) + struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ne_tensor* V = + ne_view_3d(ctx0, model.memory_v, n_past + N, n_embd / n_head, n_head, n_ctx * ne_element_size(model.memory_v), + n_ctx * ne_element_size(model.memory_v) * n_embd / n_head, + il * n_ctx * ne_element_size(model.memory_v) * n_embd); + + // KQV = transpose(V) * KQ_soft_max + struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ne_cpy(ctx0, KQV_merged, d_ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N)); + + // projection (no bias) + cur = ne_mul_mat(ctx0, model.layers[il].c_attn_proj_w, cur); + + struct ne_tensor* inpFF = cur; + + // feed-forward network + // this is independent of the self-attention result, so it could be done in parallel to the self-attention + // note here we pass inpSA instead of cur + cur = ne_mul_mat(ctx0, model.layers[il].c_mlp_fc_w, inpSA); + + cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur), cur); + + // GELU activation + cur = ne_gelu(ctx0, cur); + + // projection + // cur = proj_w*cur + proj_b + cur = ne_mul_mat(ctx0, model.layers[il].c_mlp_proj_w, cur); + + cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur), cur); + + // self-attention + FF + cur = ne_add(ctx0, cur, inpFF); + + // input for next layer + inpL = ne_add(ctx0, cur, inpL); + } + // norm + inpL = ne_norm(ctx0, inpL); + + // inpL = ln_f_g*inpL + ln_f_b + inpL = ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, model.ln_f_g, inpL), inpL), ne_repeat(ctx0, model.ln_f_b, inpL)); + + // lm_head + inpL = ne_mul_mat(ctx0, model.lmh_g, inpL); + + inpL = ne_add(ctx0, ne_repeat(ctx0, model.lmh_b, inpL), inpL); + + // logits -> probs + // inpL = ne_soft_max_inplace(ctx0, inpL); + + // run the computation + ne_build_forward_expand(&gf, inpL); + ne_graph_compute(ctx0, &gf); + + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float*)ne_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ne_used_mem(ctx0) / N; + } + // printf("used_mem = %zu\n", ne_used_mem(ctx0)); + + ne_free(ctx0); + + return true; +} + +extern "C" { +void* init_gptj(int seed, int n_predict, int top_k, float top_p, float temp, float repeat_penalty, bool perplexity, + int n_ctx, const char* model_file) { + ne_time_init(); + + const int64_t t_main_start_us = ne_time_us(); + + common_params params; + params.model = "models/gpt-j-6B/ggml-model.bin"; + + params.seed = seed; + params.n_predict = n_predict; + params.top_k = top_k; + params.top_p = top_p; + params.temp = temp; + params.repeat_penalty = repeat_penalty; + params.perplexity = perplexity; + params.n_ctx = n_ctx; + params.model = model_file; + + if (params.seed < 0) { + params.seed = time(NULL); + } + + printf("%s: seed = %d\n", __func__, params.seed); + + int64_t t_load_us = 0; + + gptj_all* all = new gptj_all; + all->vocab = new gpt_vocab; + all->model = new gptj_model; + + // load the model + { + const int64_t t_start_us = ne_time_us(); + + if (!gptj_model_load(params.model, *(all->model), *(all->vocab))) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return nullptr; + } + + t_load_us = ne_time_us() - t_start_us; + } + return (void*)(all); +} + +int32_t* eval_gptj_ids(void* all, int32_t* embd_inp_ptr, int ind_size, int n_predict, int top_k, float top_p, + float temp, int n_batch) { + gptj_all* all_in_one = (gptj_all*)all; + gpt_vocab* vocab_ptr = (all_in_one->vocab); + gptj_model* model_ptr = (all_in_one->model); + std::mt19937 rng(1234); + + int n_past = 0; + + std::vector logits; + std::vector embd_inp(embd_inp_ptr, embd_inp_ptr + ind_size); + n_predict = std::min(n_predict, model_ptr->hparams.n_ctx - (int)embd_inp.size()); + std::vector res; + std::vector embd; + + // determine the required inference memory per token: + size_t mem_per_token = 0; + gptj_eval(*model_ptr, N_thread, 0, {0, 1, 2, 3}, logits, mem_per_token); + + for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) { + // predict + if (embd.size() > 0) { + if (!gptj_eval(*model_ptr, N_thread, n_past, embd, logits, mem_per_token)) { + printf("Failed to predict\n"); + return {}; + } + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + const int n_vocab = model_ptr->hparams.n_vocab; + gpt_vocab::id id = 0; + id = gpt_sample_top_k_top_p(*vocab_ptr, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); + // add it to the context + embd.push_back(id); + } else { + // if here, it means we are still processing the input prompt + for (int k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + if (embd.size() > n_batch) { + break; + } + } + i += embd.size() - 1; + } + res.insert(res.end(), embd.begin(), embd.end()); + + // end of text token + if (embd.back() == 50256) { + break; + } + } + int32_t* res_ptr = new int32_t[res.size()]; + std::copy(res.begin(), res.end(), res_ptr); + return res_ptr; +} + +char* eval_gptj_char(void* all, const char* prom, int n_predict, int top_k, float top_p, float temp, int n_batch) { + gptj_all* all_in_one = (gptj_all*)all; + gpt_vocab* vocab_ptr = (all_in_one->vocab); + gptj_model* model_ptr = (all_in_one->model); + std::string prompt(prom); + std::mt19937 rng(1234); + + int n_past = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = ::gpt_tokenize(*vocab_ptr, prompt); + n_predict = std::min(n_predict, model_ptr->hparams.n_ctx - (int)embd_inp.size()); + std::string res; + std::vector embd; + + // determine the required inference memory per token: + size_t mem_per_token = 0; + gptj_eval(*model_ptr, N_thread, 0, {0, 1, 2, 3}, logits, mem_per_token); + + for (int i = embd.size(); i < embd_inp.size() + n_predict; i++) { + // predict + if (embd.size() > 0) { + if (!gptj_eval(*model_ptr, N_thread, n_past, embd, logits, mem_per_token)) { + printf("Failed to predict\n"); + return NULL; + } + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + const int n_vocab = model_ptr->hparams.n_vocab; + gpt_vocab::id id = 0; + id = gpt_sample_top_k_top_p(*vocab_ptr, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); + // add it to the context + embd.push_back(id); + } else { + // if here, it means we are still processing the input prompt + for (int k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + if (embd.size() > n_batch) { + break; + } + } + i += embd.size() - 1; + } + for (auto id : embd) { + res += vocab_ptr->id_to_token[id]; + } + + // end of text token + if (embd.back() == 50256) { + break; + } + } + + char* res_c_str = new char[res.size() + 1]; + std::strcpy(res_c_str, res.c_str()); + // ne_free(model_ptr->ctx); + return res_c_str; +} + +void exit_gptj(void* all) { + gptj_all* gptj_in_all = (gptj_all*)all; + delete gptj_in_all; +} +} + +int main() { + auto gptj_in_all = init_gptj(1234, 32, 0, 1.0, 0.8, 1.02, false, 2048, "../ne-q4_0.bin"); + auto res = eval_gptj_char(gptj_in_all, "she opened the door and saw", 32, 0, 1.0, 0.8, 1); + std::cout << res << std::endl; + auto res1 = eval_gptj_char(gptj_in_all, "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun", 32, 0, 1.0, 0.8, 1); + std::cout << res1 << std::endl; + std::vector embd_inp = {7091,4721,262,3420,290,2497}; + auto res_ids = eval_gptj_ids(gptj_in_all, embd_inp.data(), embd_inp.size(), 32, 0, 1.0, 0.8, 1); + exit_gptj(gptj_in_all); + delete[] res; + delete[] res1; + delete[] res_ids; + return 0; +} diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/quant_gptj.cpp b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/quant_gptj.cpp new file mode 100644 index 00000000000..ea61e4a4ddf --- /dev/null +++ b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatGPTJ/quant_gptj.cpp @@ -0,0 +1,194 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "application/common.h" +#include "core/ne.h" + +// default hparams (GPT-J 6B) +struct gptj_hparams { + int32_t n_vocab = 50400; + int32_t n_ctx = 2048; + int32_t n_embd = 4096; + int32_t n_head = 16; + int32_t n_layer = 28; + int32_t n_rot = 64; + int32_t ftype = 1; +}; + +// quantize a model +bool gptj_model_quantize(const std::string& fname_inp, const std::string& fname_out, ne_ftype ftype) { + gpt_vocab vocab; + + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + finp.read((char*)&magic, sizeof(magic)); + if (magic != NE_FILE_MAGIC) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); + return false; + } + + fout.write((char*)&magic, sizeof(magic)); + } + + gptj_hparams hparams; + + // load hparams + { + finp.read((char*)&hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char*)&hparams.n_ctx, sizeof(hparams.n_ctx)); + finp.read((char*)&hparams.n_embd, sizeof(hparams.n_embd)); + finp.read((char*)&hparams.n_head, sizeof(hparams.n_head)); + finp.read((char*)&hparams.n_layer, sizeof(hparams.n_layer)); + finp.read((char*)&hparams.n_rot, sizeof(hparams.n_rot)); + finp.read((char*)&hparams.ftype, sizeof(hparams.ftype)); + + const int32_t qntvr_src = hparams.ftype / NE_QNT_VERSION_FACTOR; + const int32_t ftype_dst = NE_QNT_VERSION * NE_QNT_VERSION_FACTOR + ftype; + + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); + printf("%s: ftype (src) = %d\n", __func__, hparams.ftype); + printf("%s: qntvr (src) = %d\n", __func__, qntvr_src); + printf("%s: ftype (dst) = %d\n", __func__, ftype_dst); + printf("%s: qntvr (dst) = %d\n", __func__, NE_QNT_VERSION); + + fout.write((char*)&hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char*)&hparams.n_ctx, sizeof(hparams.n_ctx)); + fout.write((char*)&hparams.n_embd, sizeof(hparams.n_embd)); + fout.write((char*)&hparams.n_head, sizeof(hparams.n_head)); + fout.write((char*)&hparams.n_layer, sizeof(hparams.n_layer)); + fout.write((char*)&hparams.n_rot, sizeof(hparams.n_rot)); + fout.write((char*)&ftype_dst, sizeof(ftype_dst)); + } + + // load vocab + { + int32_t n_vocab = 0; + finp.read((char*)&n_vocab, sizeof(n_vocab)); + fout.write((char*)&n_vocab, sizeof(n_vocab)); + + if (n_vocab != hparams.n_vocab) { + fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", __func__, fname_inp.c_str(), n_vocab, + hparams.n_vocab); + return false; + } + + std::string word; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + finp.read((char*)&len, sizeof(len)); + fout.write((char*)&len, sizeof(len)); + + word.resize(len); + finp.read((char*)word.data(), len); + fout.write((char*)word.data(), len); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + // regexes of tensor names to be quantized + const std::vector to_quant = { + ".*weight", + }; + + //if (!ne_common_quantize_0(finp, fout, ftype, to_quant, {"transformer.wte.weight"})) { + if (!ne_common_quantize_0(finp, fout, ftype, to_quant, {})) { + fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str()); + return false; + } + + finp.close(); + fout.close(); + + return true; +} + +int main(int argc, char** argv) { + quant_params q_params; + if (quant_params_parse(argc, argv, q_params) == false) { + return 1; + } + const std::string fname_inp = q_params.model_file; + const std::string fname_out = q_params.out_file; + if (!isValidFilename(fname_inp)) { + fprintf(stderr, "invalid file names '%s'\n", fname_inp.c_str()); + return 1; + } + ne_ftype ftype = NE_FTYPE_MAP[ + std::make_tuple(q_params.bits, q_params.alg, q_params.block_size, q_params.scale_dtype, q_params.gemm_isa)]; + + // needed to initialize f16 tables + { + struct ne_init_params params = {0, NULL, false}; + struct ne_context* ctx = ne_init(params); + ne_free(ctx); + } + + const int64_t t_main_start_us = ne_time_us(); + + int64_t t_quantize_us = 0; + + // load the model + { + const int64_t t_start_us = ne_time_us(); + + if (!gptj_model_quantize(fname_inp, fname_out, ne_ftype(ftype))) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); + return 1; + } + + t_quantize_us = ne_time_us() - t_start_us; + } + + // report timing + { + const int64_t t_main_end_us = ne_time_us(); + + printf("\n"); + printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us / 1000.0f); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); + } + + return 0; +} diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatLLAMA/quant_llama.cpp b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatLLAMA/quant_llama.cpp index c9cbcb62619..20bda3c513a 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatLLAMA/quant_llama.cpp +++ b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatLLAMA/quant_llama.cpp @@ -38,21 +38,21 @@ struct MyHash { }; -static std::unordered_map, enum model_ftype, MyHash> +static std::unordered_map, enum ne_ftype, MyHash> NE_FTYPE_MAP = { // bits, alg, block size, scale dtype, gemm_isa -> ne_ftype - {{4, "sym", QK4_0, "fp32", "none"}, MODEL_FTYPE_MOSTLY_Q4_0}, - {{4, "asym", QK4_1, "fp32", "none"}, MODEL_FTYPE_MOSTLY_Q4_1}, - {{5, "sym", QK5_0, "fp32", "none"}, MODEL_FTYPE_MOSTLY_Q5_0}, - {{5, "asym", QK5_1, "fp32", "none"}, MODEL_FTYPE_MOSTLY_Q5_1}, - {{8, "sym", QK8_0, "fp32", "none"}, MODEL_FTYPE_MOSTLY_Q8_0}, - {{4, "sym", 32, "fp32", "amx"}, MODEL_FTYPE_MOSTLY_Q4_JBLAS_B32}, - {{4, "sym", 32, "bf16", "amx"}, MODEL_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32}, - {{4, "sym", 128, "fp32", "amx"}, MODEL_FTYPE_MOSTLY_Q4_JBLAS_B128}, - {{4, "sym", -1024, "fp32", "amx"}, MODEL_FTYPE_MOSTLY_Q4_JBLAS_B128}, - {{4, "sym", 32, "fp32", "vnni"}, MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32}, - {{4, "sym", 128, "fp32", "vnni"}, MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128}, - {{4, "sym", 32, "bf16", "vnni"}, MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32}, + {{4, "sym", QK4_0, "fp32", "none"}, NE_FTYPE_MOSTLY_Q4_0}, + {{4, "asym", QK4_1, "fp32", "none"}, NE_FTYPE_MOSTLY_Q4_1}, + {{5, "sym", QK5_0, "fp32", "none"}, NE_FTYPE_MOSTLY_Q5_0}, + {{5, "asym", QK5_1, "fp32", "none"}, NE_FTYPE_MOSTLY_Q5_1}, + {{8, "sym", QK8_0, "fp32", "none"}, NE_FTYPE_MOSTLY_Q8_0}, + {{4, "sym", 32, "fp32", "amx"}, NE_FTYPE_MOSTLY_Q4_JBLAS_B32}, + {{4, "sym", 32, "bf16", "amx"}, NE_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32}, + {{4, "sym", 128, "fp32", "amx"}, NE_FTYPE_MOSTLY_Q4_JBLAS_B128}, + {{4, "sym", -1024, "fp32", "amx"}, NE_FTYPE_MOSTLY_Q4_JBLAS_B128}, + {{4, "sym", 32, "fp32", "vnni"}, NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32}, + {{4, "sym", 128, "fp32", "vnni"}, NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128}, + {{4, "sym", 32, "bf16", "vnni"}, NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32}, }; struct quant_params { @@ -125,9 +125,9 @@ int main(int argc, char** argv) { } const std::string fname_inp = q_params.model_file; const std::string fname_out = q_params.out_file; - model_ftype ftype = NE_FTYPE_MAP[ + ne_ftype ftype = NE_FTYPE_MAP[ std::make_tuple(q_params.bits, q_params.alg, q_params.block_size, q_params.scale_dtype, q_params.gemm_isa)]; - printf("model_ftype: %d\n", ftype); + printf("ne_ftype: %d\n", ftype); const int nthread = q_params.nthread; const int64_t t_main_start_us = model_time_us(); diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/application/common.cpp b/intel_extension_for_transformers/backends/neural_engine/graph/application/common.cpp index 2551322e9a0..221db01d4b4 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/application/common.cpp +++ b/intel_extension_for_transformers/backends/neural_engine/graph/application/common.cpp @@ -700,6 +700,15 @@ bool ne_common_quantize_0(std::ifstream& finp, std::ofstream& fout, const ne_fty case NE_FTYPE_MOSTLY_Q8_0: qtype = NE_TYPE_Q8_0; break; + case NE_FTYPE_MOSTLY_Q4_JBLAS_B32: + case NE_FTYPE_MOSTLY_Q4_JBLAS_B128: + case NE_FTYPE_MOSTLY_Q4_JBLAS_B1024: + case NE_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32: + case NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32: + case NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32: + case NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128: + qtype = NE_TYPE_Q4_JBLAS; + break; case NE_FTYPE_UNKNOWN: case NE_FTYPE_ALL_F32: case NE_FTYPE_MOSTLY_F16: diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/core/data_types.h b/intel_extension_for_transformers/backends/neural_engine/graph/core/data_types.h index bf5d06447bc..bff901fb0ce 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/core/data_types.h +++ b/intel_extension_for_transformers/backends/neural_engine/graph/core/data_types.h @@ -56,6 +56,13 @@ enum ne_ftype { NE_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors NE_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors NE_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + NE_FTYPE_MOSTLY_Q4_JBLAS_B32 = 10, // except 1d tensors + NE_FTYPE_MOSTLY_Q4_JBLAS_B128 = 11, // except 1d tensors + NE_FTYPE_MOSTLY_Q4_JBLAS_B1024 = 12, // except 1d tensors + NE_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32 = 13, // except 1d tensors + NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32 = 14, // except 1d tensors + NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32 = 15, // except 1d tensors + NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128 = 16, // except 1d tensors }; #define QK4_0 32 diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/core/ne_layers.h b/intel_extension_for_transformers/backends/neural_engine/graph/core/ne_layers.h index f3dbc802019..65adb913dd0 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/core/ne_layers.h +++ b/intel_extension_for_transformers/backends/neural_engine/graph/core/ne_layers.h @@ -119,11 +119,11 @@ NE_API struct ne_tensor* ne_new_tensor_3d(struct ne_context* ctx, enum ne_type t NE_API struct ne_tensor* ne_new_tensor_4d(struct ne_context* ctx, enum ne_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, size_t size); -#define d_ne_new_tensor(...) ne_new_tensor(##__VA_ARGS__, NE_SIZE_CALC) -#define d_ne_new_tensor_1d(...) ne_new_tensor_1d(##__VA_ARGS__, NE_SIZE_CALC) -#define d_ne_new_tensor_2d(...) ne_new_tensor_2d(##__VA_ARGS__, NE_SIZE_CALC) -#define d_ne_new_tensor_3d(...) ne_new_tensor_3d(##__VA_ARGS__, NE_SIZE_CALC) -#define d_ne_new_tensor_4d(...) ne_new_tensor_4d(##__VA_ARGS__, NE_SIZE_CALC) +#define d_ne_new_tensor(...) ne_new_tensor(__VA_ARGS__,NE_SIZE_CALC) +#define d_ne_new_tensor_1d(...) ne_new_tensor_1d(__VA_ARGS__, NE_SIZE_CALC) +#define d_ne_new_tensor_2d(...) ne_new_tensor_2d(__VA_ARGS__, NE_SIZE_CALC) +#define d_ne_new_tensor_3d(...) ne_new_tensor_3d(__VA_ARGS__, NE_SIZE_CALC) +#define d_ne_new_tensor_4d(...) ne_new_tensor_4d(__VA_ARGS__, NE_SIZE_CALC) NE_API struct ne_tensor* ne_new_i32(struct ne_context* ctx, int32_t value); NE_API struct ne_tensor* ne_new_f32(struct ne_context* ctx, float value); diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/llama_model.cpp b/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/llama_model.cpp index 81d457a1895..6825c71aac9 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/llama_model.cpp +++ b/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/llama_model.cpp @@ -223,7 +223,7 @@ struct model_file_loader { hparams.n_head = file.read_u32(); hparams.n_layer = file.read_u32(); hparams.n_rot = file.read_u32(); - hparams.ftype = (enum model_ftype)file.read_u32(); + hparams.ftype = (enum ne_ftype)file.read_u32(); } void read_vocab() { vocab.id_to_token.resize(hparams.n_vocab); @@ -304,7 +304,7 @@ struct model_file_loader { struct model_file_saver { model_file file; model_file_loader* any_file_loader; - model_file_saver(const char* fname, model_file_loader* any_file_loader, enum model_ftype new_ftype) + model_file_saver(const char* fname, model_file_loader* any_file_loader, enum ne_ftype new_ftype) : file(fname, "wb"), any_file_loader(any_file_loader) { fprintf(stderr, "model.cpp: saving model to %s\n", fname); write_magic(); @@ -315,7 +315,7 @@ struct model_file_saver { file.write_u32(MODEL_FILE_MAGIC); // magic file.write_u32(MODEL_FILE_VERSION); // version } - void write_hparams(enum model_ftype new_ftype) { + void write_hparams(enum ne_ftype new_ftype) { const model_hparams& hparams = any_file_loader->hparams; file.write_u32(hparams.n_vocab); file.write_u32(hparams.n_embd); @@ -650,23 +650,23 @@ static const char* model_file_version_name(model_file_version version) { return "unknown"; } -static const char* model_ftype_name(enum model_ftype ftype) { +static const char* ne_ftype_name(enum ne_ftype ftype) { switch (ftype) { - case MODEL_FTYPE_ALL_F32: + case NE_FTYPE_ALL_F32: return "all F32"; - case MODEL_FTYPE_MOSTLY_F16: + case NE_FTYPE_MOSTLY_F16: return "mostly F16"; - case MODEL_FTYPE_MOSTLY_Q4_0: + case NE_FTYPE_MOSTLY_Q4_0: return "mostly Q4_0"; - case MODEL_FTYPE_MOSTLY_Q4_1: + case NE_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1"; - case MODEL_FTYPE_MOSTLY_Q4_1_SOME_F16: + case NE_FTYPE_MOSTLY_Q4_1_SOME_F16: return "mostly Q4_1, some F16"; - case MODEL_FTYPE_MOSTLY_Q5_0: + case NE_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; - case MODEL_FTYPE_MOSTLY_Q5_1: + case NE_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; - case MODEL_FTYPE_MOSTLY_Q8_0: + case NE_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; default: return "unknown, may not work"; @@ -730,22 +730,22 @@ static void model_model_load_internal(const std::string& fname, model_context& l fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); - fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, model_ftype_name(hparams.ftype)); + fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, ne_ftype_name(hparams.ftype)); fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size()); fprintf(stderr, "%s: model size = %s\n", __func__, model_model_type_name(model.type)); } if (file_version < MODEL_FILE_VERSION_GGJT_V2) { - if (hparams.ftype != MODEL_FTYPE_ALL_F32 && hparams.ftype != MODEL_FTYPE_MOSTLY_F16 && - hparams.ftype != MODEL_FTYPE_MOSTLY_Q8_0) { + if (hparams.ftype != NE_FTYPE_ALL_F32 && hparams.ftype != NE_FTYPE_MOSTLY_F16 && + hparams.ftype != NE_FTYPE_MOSTLY_Q8_0) { throw format("this format is no longer supported (see https://github.com/ggerganov/model.cpp/pull/1405)"); } } if (file_version < MODEL_FILE_VERSION_GGJT_V3) { - if (hparams.ftype == MODEL_FTYPE_MOSTLY_Q4_0 || hparams.ftype == MODEL_FTYPE_MOSTLY_Q4_1 || - hparams.ftype == MODEL_FTYPE_MOSTLY_Q8_0) { + if (hparams.ftype == NE_FTYPE_MOSTLY_Q4_0 || hparams.ftype == NE_FTYPE_MOSTLY_Q4_1 || + hparams.ftype == NE_FTYPE_MOSTLY_Q8_0) { throw format("this format is no longer supported (see https://github.com/ggerganov/model.cpp/pull/1508)"); } } @@ -1453,31 +1453,31 @@ model_token model_sample_token(struct model_context* ctx, model_token_data_array // static void model_model_quantize_internal(const std::string& fname_inp, const std::string& fname_out, - enum model_ftype ftype, int nthread) { + enum ne_ftype ftype, int nthread) { ne_type quantized_type; switch (ftype) { - case MODEL_FTYPE_MOSTLY_Q4_0: + case NE_FTYPE_MOSTLY_Q4_0: quantized_type = NE_TYPE_Q4_0; break; - case MODEL_FTYPE_MOSTLY_Q4_1: + case NE_FTYPE_MOSTLY_Q4_1: quantized_type = NE_TYPE_Q4_1; break; - case MODEL_FTYPE_MOSTLY_Q5_0: + case NE_FTYPE_MOSTLY_Q5_0: quantized_type = NE_TYPE_Q5_0; break; - case MODEL_FTYPE_MOSTLY_Q5_1: + case NE_FTYPE_MOSTLY_Q5_1: quantized_type = NE_TYPE_Q5_1; break; - case MODEL_FTYPE_MOSTLY_Q8_0: + case NE_FTYPE_MOSTLY_Q8_0: quantized_type = NE_TYPE_Q8_0; break; - case MODEL_FTYPE_MOSTLY_Q4_JBLAS_B32: - case MODEL_FTYPE_MOSTLY_Q4_JBLAS_B128: - case MODEL_FTYPE_MOSTLY_Q4_JBLAS_B1024: - case MODEL_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32: - case MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32: - case MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32: - case MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128: + case NE_FTYPE_MOSTLY_Q4_JBLAS_B32: + case NE_FTYPE_MOSTLY_Q4_JBLAS_B128: + case NE_FTYPE_MOSTLY_Q4_JBLAS_B1024: + case NE_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32: + case NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32: + case NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32: + case NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128: quantized_type = NE_TYPE_Q4_JBLAS; break; default: @@ -1566,31 +1566,31 @@ static void model_model_quantize_internal(const std::string& fname_inp, const st jblas::prologue::PackedWeight* packedw = NULL; int blocksize = 32; auto type = CompType::S4_F32; - if (ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_B32) { + if (ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_B32) { blocksize = 32; type = CompType::S4_F32; - } else if (ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_B128) { + } else if (ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_B128) { blocksize = 128; type = CompType::S4_F32; - } else if (ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_B1024) { + } else if (ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_B1024) { blocksize = 1024; type = CompType::S4_F32; - } else if (ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32) { + } else if (ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32) { blocksize = 32; type = CompType::S4_Bf16; - } else if (ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32) { + } else if (ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32) { blocksize = 32; type = CompType::S4_F32; - } else if (ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128) { + } else if (ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128) { blocksize = 128; type = CompType::S4_F32; - } else if (ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32) { + } else if (ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32) { blocksize = 32; type = CompType::S4_Bf16; } auto cd = jblas::utils::parallel::CpuDevice::getInstance(); - if (ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32 || ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128 || - ftype == MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32) { + if (ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32 || ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128 || + ftype == NE_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32) { if (cd->AVX512F()) { packedw = vnnikernel.getWeightPtr()->compressWeightTranspose(n_, k_, (float*)tensor.data, k_, blocksize, type); @@ -1783,7 +1783,7 @@ struct model_context* model_init_from_file(const char* path_model, struct model_ void model_free(struct model_context* ctx) { delete ctx; } -int model_model_quantize(const char* fname_inp, const char* fname_out, enum model_ftype ftype, int nthread) { +int model_model_quantize(const char* fname_inp, const char* fname_out, enum ne_ftype ftype, int nthread) { try { model_model_quantize_internal(fname_inp, fname_out, ftype, nthread); return 0; diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/llama_model.h b/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/llama_model.h index 34263bbff18..937b1a336ce 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/llama_model.h +++ b/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/llama_model.h @@ -73,7 +73,7 @@ MODEL_API void model_free(struct model_context* ctx); // TODO: not great API - very likely to change // Returns 0 on success // nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given -MODEL_API int model_model_quantize(const char* fname_inp, const char* fname_out, model_ftype ftype, int nthread); +MODEL_API int model_model_quantize(const char* fname_inp, const char* fname_out, ne_ftype ftype, int nthread); // Apply a LoRA adapter to a loaded model // path_base_model is the path to a higher quality model to use as a base for diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/model_types.h b/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/model_types.h index 7277c1dd091..3fd98141d59 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/model_types.h +++ b/intel_extension_for_transformers/backends/neural_engine/graph/models/llama/model_types.h @@ -132,27 +132,6 @@ static const std::map& MEM_REQ_EVAL() { return k_sizes; } -// model file types -enum model_ftype { - MODEL_FTYPE_ALL_F32 = 0, - MODEL_FTYPE_MOSTLY_F16 = 1, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - // MODEL_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed - // MODEL_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed - MODEL_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_JBLAS_B32 = 10, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_JBLAS_B128 = 11, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_JBLAS_B1024 = 12, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_JBLAS_BF16_B32 = 13, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B32 = 14, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_BF16_B32 = 15, // except 1d tensors - MODEL_FTYPE_MOSTLY_Q4_JBLAS_VNNI_B128 = 16, // except 1d tensors -}; - enum model_file_version { MODEL_FILE_VERSION_NE, MODEL_FILE_VERSION_GGMF_V1, // added version field and scores in vocab @@ -176,7 +155,7 @@ struct model_hparams { uint32_t n_head = 32; uint32_t n_layer = 32; uint32_t n_rot = 64; - enum model_ftype ftype = MODEL_FTYPE_MOSTLY_F16; + enum ne_ftype ftype = NE_FTYPE_MOSTLY_F16; bool operator!=(const model_hparams& other) const { return static_cast(memcmp(this, &other, sizeof(model_hparams))); diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/scripts/convert_gptj.py b/intel_extension_for_transformers/backends/neural_engine/graph/scripts/convert_gptj.py new file mode 100644 index 00000000000..dbd69b6a29b --- /dev/null +++ b/intel_extension_for_transformers/backends/neural_engine/graph/scripts/convert_gptj.py @@ -0,0 +1,152 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Convert Hugging Face fine-tuned gpt-neox-like models to ne format +# +# Usage: +# +# python3 models/convert-h5-to-ne.py +# +# This script is similar to "convert-pt-to-ne.py" +# + +import sys +import struct +import json +import torch +import numpy as np +from pathlib import Path +import argparse +from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, + Literal, Optional, Sequence, Tuple, TypeVar, Union) +from transformers import GPTJForCausalLM + +# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + +def main(args_in: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file") + parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file") + args = parser.parse_args(args_in) + + dir_model = args.model.as_posix() + fname_out = args.outfile.as_posix() + + ftype = 0 + if args.outtype== "f16": + ftype = 1 + + # output in the same directory as the model + with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f: + encoder = json.load(f) + + with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f: + encoder_added = json.load(f) + + with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + + print("Loading model: ", dir_model) + model = GPTJForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True) + list_vars = model.state_dict() + fout = open(fname_out, "wb") + + fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex + fout.write(struct.pack("i", hparams["vocab_size"])) + fout.write(struct.pack("i", hparams["n_positions"])) + fout.write(struct.pack("i", hparams["n_embd"])) + fout.write(struct.pack("i", hparams["n_head"])) + fout.write(struct.pack("i", hparams["n_layer"])) + fout.write(struct.pack("i", hparams["rotary_dim"])) + fout.write(struct.pack("i", ftype)) + + byte_encoder = bytes_to_unicode() + byte_decoder = {v:k for k, v in byte_encoder.items()} + + fout.write(struct.pack("i", len(encoder) + len(encoder_added))) + + for key in encoder: + text = bytearray([byte_decoder[c] for c in key]) + fout.write(struct.pack("i", len(text))) + fout.write(text) + + for key in encoder_added: + text = bytearray([byte_decoder[c] for c in key]) + fout.write(struct.pack("i", len(text))) + fout.write(text) + + for name in list_vars.keys(): + data = list_vars[name].squeeze().numpy() + print("Processing variable: " + name + " with shape: ", data.shape) + + # we don't need these + if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"): + print(" Skipping variable: " + name) + continue + + n_dims = len(data.shape); + + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0; + if ftype != 0: + if name[-7:] == ".weight" and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + str = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(str); + + # data + data.tofile(fout) + + fout.close() + + print("Done. Output file: " + fname_out) + print("") + +if __name__ == '__main__': + main() diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/scripts/gptj_binding.py b/intel_extension_for_transformers/backends/neural_engine/graph/scripts/gptj_binding.py new file mode 100644 index 00000000000..ca0cffc4de0 --- /dev/null +++ b/intel_extension_for_transformers/backends/neural_engine/graph/scripts/gptj_binding.py @@ -0,0 +1,45 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ctypes import * +import numpy as np + +lib = cdll.LoadLibrary('./lib/libGptjPyBind.so') + +init_gptj = lib.init_gptj +init_gptj.argtypes = [c_int, c_int, c_int, c_float, c_float, c_float, c_bool, c_int, c_char_p] +init_gptj.restype = c_void_p + +gptj_in_all = init_gptj(1234, 32, 0, 1.0, 0.8, 1.5, False, 2048, b"../ne-q4_0.bin") + +eval_gptj_char = lib.eval_gptj_char +eval_gptj_char.argtypes = [c_void_p, c_char_p, c_int, c_int, c_float, c_float, c_int] +eval_gptj_char.restype = c_char_p + +#res = eval_gptj_char(gptj_in_all, b"she opened the door and saw", 32, 0, 1.0, 0.8, 1) + +eval_gptj_ids = lib.eval_gptj_ids +eval_gptj_ids.argtypes = [c_void_p, np.ctypeslib.ndpointer(dtype=np.int32, ndim=1, flags='C_CONTIGUOUS'), c_int, c_int, c_int, c_float, c_float, c_int] +eval_gptj_ids.restype = np.ctypeslib.ndpointer(dtype=np.int32, ndim=1, flags='C_CONTIGUOUS') + +#res = eval_gptj_ids(gptj_in_all, np.array([7091, 4721, 262, 3420, 290, 2497], dtype=np.int32), 6, 32, 0, 1.0, 0.8, 1) +res = eval_gptj_ids(gptj_in_all, np.array([7454, 2402, 257, 640, 11, 612, 11196, 257, 1310, 2576, 11, 508, 8288, 284, 423, 17545, 13, 1375, 2227, 284, 467, 284, 4113, 290, 1826, 649, 661, 11, 290, 423, 1257], dtype=np.int32), 31, 32, 0, 1.0, 0.8, 1) + +ctypes_pntr = cast(res, POINTER(c_int)) +res_np = np.ctypeslib.as_array(ctypes_pntr, shape=(31,)) +exit_gptj = lib.exit_gptj +exit_gptj.argtypes = [c_void_p] +exit_gptj.restype = None + +exit_gptj(gptj_in_all)