From 54ef7ce4d12b38600b3a05ef757c95e2001ffec3 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Fri, 6 Oct 2023 14:25:38 +0800 Subject: [PATCH 01/11] add ppl --- .gitignore | 8 + .../llm/runtime/graph/__init__.py | 39 ++- .../runtime/graph/application/main_pybind.cpp | 167 +++++++---- .../llm/runtime/graph/requirements.txt | 3 +- .../llm/runtime/graph/scripts/perplexity.py | 271 ++++++++++++++++++ 5 files changed, 409 insertions(+), 79 deletions(-) create mode 100644 intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py diff --git a/.gitignore b/.gitignore index 105a2c6a5fc..62b3e04a79e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +/intel_extension_for_transformers/llm/runtime/graph/* +!/intel_extension_for_transformers/llm/runtime/graph/*.* +!/intel_extension_for_transformers/llm/runtime/graph/*/ +### ignore binary files in llm-runtime ### + *.pyc .vscode .idea @@ -11,6 +16,7 @@ *.log *.swp *.onnx +*.bin tags build/ _build @@ -32,6 +38,8 @@ CMakeUserPresets.json /intel_extension_for_transformers/llm/runtime/.vs /intel_extension_for_transformers/llm/runtime/out +/intel_extension_for_transformers/llm/runtime/graph/out +/intel_extension_for_transformers/llm/runtime/graph/runtime_outs /examples/**/*.npy /examples/**/*.bin /examples/**/*.yaml diff --git a/intel_extension_for_transformers/llm/runtime/graph/__init__.py b/intel_extension_for_transformers/llm/runtime/graph/__init__.py index 579ef7b511c..9433ed7054c 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/__init__.py +++ b/intel_extension_for_transformers/llm/runtime/graph/__init__.py @@ -15,11 +15,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from transformers import AutoConfig, AutoTokenizer -from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model + import torch +from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model +from transformers import AutoConfig, AutoTokenizer + model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"} + class Model: def __init__(self): self.module = None @@ -73,8 +76,7 @@ def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs): # check cache and quantization output_path = "runtime_outs" - if not os.path.exists(output_path): - os.makedirs(output_path) + os.makedirs(output_path, exist_ok=True) fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type) quant_bin = "{}/ne_{}_q.bin".format(output_path, model_type) @@ -91,9 +93,8 @@ def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs): if not_quant: print("FP32 model will be used.") return - self.module.Model.quant_model(model_path = fp32_bin, out_path = quant_bin, **quant_kwargs) + self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, **quant_kwargs) assert os.path.exists(quant_bin), "Fail to quantize model" - # clean os.remove(fp32_bin) @@ -110,9 +111,7 @@ def init_from_bin(self, model_name, model_path, **generate_kwargs): def quant_model(self, model_name, model_path, out_path, **quant_kwargs): self.__import_package(model_name) - self.module.Model.quant_model(model_path = model_path, - out_path = out_path, **quant_kwargs) - + self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs) def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs): max_new_tokens = generate_kwargs.get("max_new_tokens", -1) @@ 
-129,8 +128,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa ret = input_ids.tolist() beam_search = False - if ("num_beams" in generate_kwargs and generate_kwargs["num_beams"] > 1) and not \ - generate_kwargs.get("do_sample", False): + if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False): beam_search = True if not beam_search: # TODO support multi batch @@ -142,12 +140,14 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa Make sure that `num_beams` is set to 1." if self.generate_round == 0 and not ignore_prompt: streamer.put(input_ids) - + if interactive: self.model.reset_token_end() out_count = 0 + input_list = input_ids.tolist() while True: - response = self.model.generate(input_ids = input_ids.tolist()) + response = self.model.generate(input_ids=input_list) + input_list = [] # next-token stage will use previous output if len(response) == 0: break if streamer: @@ -158,14 +158,23 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa if stopping_criteria(torch.tensor(ret), None): break elif ret[0][-1] == self.tokenizer.eos_token_id or \ - (max_new_tokens != -1 and out_count > max_new_tokens): + (max_new_tokens != -1 and out_count > max_new_tokens): break out_count += 1 if streamer: streamer.end() - + self.generate_round += 1 return ret def is_token_end(self): return self.model.is_token_end() + + def __call__(self, input_ids, reinit=False, **kwargs): + if self.model is None: + self.init_from_bin(self.model_type, self.bin_file, **kwargs) + self.generate_round = 0 + elif reinit: + self.model.reinit() + self.generate_round = 0 + return self.model.evaluate(input_ids.tolist()) diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index 467fce54e9f..294841596cc 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -58,14 +58,23 @@ class Model { void init_model(const std::string& model_path, int n_predict, int n_batch, int ctx_size, int seed, int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, - bool shift_roped_k, int batch_size, model_vocab::id pad_token); + bool shift_roped_k, int batch_size, model_vocab::id pad_token, const std::string& memory_dtype); void reinit(); std::vector> generate(const std::vector>& input_ids); std::vector> generate_tokens(const std::vector>& input_ids); + const std::vector& evaluate_(const std::vector>& input_ids); + std::vector> evaluate(const std::vector>& input_ids) { + if (input_ids.size() != 1) { + fprintf(stderr, "\nERROR: only support batch == 1 input!\n"); + return {{}}; + } + const auto& logits = evaluate_(input_ids); + return {logits}; + } bool is_token_end() { return token_eos; } static int quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype, const std::string& alg, int group_size, const std::string& scale_dtype, - const std::string& compute_dtype, bool use_ggml); + const std::string& compute_dtype, bool use_ggml, int threads); void reset_token_end() { token_eos = false; curr_input_ids.clear(); @@ -84,17 +93,19 @@ class Model { bool token_eos = false; long int generate_count = 0; 
- model_token post_process(float* logits); - model_token post_greedy_search(float* logits); + std::vector> beam_generate(const std::vector>& input_ids); + model_token post_process(const float* logits); + model_token post_greedy_search(const float* logits); std::vector> post_beam_search(model_context* lctx, const int& n_predict, const std::vector& inputs, const int& n_threads); - model_token post_sample_top_k_top_p_repeat(float* logits); + model_token post_sample_top_k_top_p_repeat(const float* logits); }; void Model::init_model(const std::string& model_path, int max_new_tokens, int n_batch, int ctx_size, int seed, int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, - int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token) { + int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token, + const std::string& memory_dtype) { #ifdef MODEL_NAME params.model_name = MODEL_NAME; #endif @@ -110,15 +121,21 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_ params.do_sample = do_sample; params.batch_size = batch_size; params.beam_search = (num_beams > 1 && !do_sample) ? true : false; - if (params.beam_search) { - params.memory_type = KV_MEM_TYPE_F16; // TODO NO MHA IN BEAM SEARCH - } params.top_k = top_k; params.top_p = top_p; params.temp = temperature; params.n_keep = n_keep; params.n_discard = n_discard; params.shift_roped_k = shift_roped_k; + if (memory_dtype == "f32") + params.memory_type = KV_MEM_TYPE_F32; + else if (memory_dtype == "f16") + params.memory_type = KV_MEM_TYPE_F16; + else if (memory_dtype == "auto") + params.memory_type = KV_MEM_TYPE_AUTO; + else + fprintf(stderr, "Unexpected memory dtype!"); + if (params.beam_search) params.memory_type = KV_MEM_TYPE_F16; // TODO(Yi): NO MHA IN BEAM SEARCH printf("beam_size: %d, do_sample: %d, top_k: %d, top_p: %f\n", params.beam_size, params.do_sample, params.top_k, params.top_p); @@ -149,56 +166,64 @@ void Model::reinit() { generate_count = 0; } -std::vector> Model::generate(const std::vector>& input_ids) { - int n_remain = params.n_predict; - std::vector> rets; - if (ctx->beam_search) { - MODEL_ASSERT(input_ids.size() == ctx->batch_size); - if (ctx->batch_size > 1 && ctx->vocab.pad_token_id == -1) { - fprintf(stderr, "\nERROR: please set pad_token for beam search multi-batch generation!\n"); - return rets; - } - std::vector inputs; - for (int bs = 0; bs < input_ids.size(); ++bs) { - uint32_t count = 0; - model_vocab::id pad_token_id = ctx->vocab.pad_token_id; - auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), - [&pad_token_id](model_token t) { return (t != pad_token_id); }); - if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); - count = std::distance(input_ids[bs].begin(), iter); - inputs.push_back(model_input{ - /*.tokens =*/input_ids[bs].data(), - /*.n_tokens =*/(uint32_t)input_ids[bs].size(), - /*.n_prompt_tokens =*/0, - /*.n_past =*/0, - /*.n_total =*/0, - /*.request_idx =*/bs, - /*.beam_idx =*/0, - /*.padding_side =*/0, - /*n_padding =*/count, - }); - } - return post_beam_search(ctx, n_remain, inputs, params.n_threads); +std::vector> Model::beam_generate(const std::vector>& input_ids) { + MODEL_ASSERT(input_ids.size() == ctx->batch_size); + if (ctx->batch_size > 1 && ctx->vocab.pad_token_id == -1) { + fprintf(stderr, "\nERROR: please set pad_token for beam search 
multi-batch generation!\n"); + return {{}}; + } + std::vector inputs; + for (int bs = 0; bs < input_ids.size(); ++bs) { + uint32_t count = 0; + model_vocab::id pad_token_id = ctx->vocab.pad_token_id; + auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), + [&pad_token_id](model_token t) { return (t != pad_token_id); }); + if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); + count = std::distance(input_ids[bs].begin(), iter); + inputs.push_back(model_input{ + /*.tokens =*/input_ids[bs].data(), + /*.n_tokens =*/(uint32_t)input_ids[bs].size(), + /*.n_prompt_tokens =*/0, + /*.n_past =*/0, + /*.n_total =*/0, + /*.request_idx =*/bs, + /*.beam_idx =*/0, + /*.padding_side =*/0, + /*n_padding =*/count, + }); } + return post_beam_search(ctx, params.n_predict, inputs, params.n_threads); +} + +const std::vector& Model::evaluate_(const std::vector>& input_ids) { + static const std::vector empty_ret{}; if (input_ids.size() > 1) { fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); - return rets; + return empty_ret; } - if (curr_input_ids.empty()) { - if (input_ids[0].size() > n_ctx - 4) { - fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, - input_ids[0].size(), n_ctx - 4); - curr_input_ids.resize(n_ctx - 4); - std::copy(input_ids[0].end() - n_ctx - 4, input_ids[0].end(), curr_input_ids.begin()); - } else { - curr_input_ids = input_ids[0]; + + const auto& input_id0 = input_ids[0]; // currently only support single batch + if (input_id0.empty()) { // use internel input id + if (curr_input_ids.empty()) { + fprintf(stderr, "%s: error: no input\n", __func__); + return empty_ret; } + } else if (!curr_input_ids.empty()) { + fprintf(stderr, "%s: error: prompt confliction\n", __func__); + return empty_ret; + } else if (input_id0.size() > n_ctx - 4) { // long input_id0 and empty curr_input_ids + fprintf(stderr, "\n%s: Warning: prompt is too long (%d tokens, max %d), will be truncated\n", __func__, + input_id0.size(), n_ctx - 4); + curr_input_ids.resize(n_ctx - 4); + std::copy(input_id0.end() - n_ctx - 4, input_id0.end(), curr_input_ids.begin()); + } else { // good input_id0 and empty curr_input_ids + curr_input_ids = input_id0; } - for (auto item : curr_input_ids) { - last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(item); - } + // push elements in curr_input_ids to the last_n_tokens queue + last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + curr_input_ids.size()); + last_n_tokens.insert(last_n_tokens.end(), curr_input_ids.begin(), curr_input_ids.end()); + // infinite text generation via context swapping if (n_past + curr_input_ids.size() > n_ctx) { // always keep the first token @@ -214,7 +239,8 @@ std::vector> Model::generate(const std::vector inputs = {model_input{ + + std::vector inputs{{ /*.tokens =*/curr_input_ids.data(), /*.n_tokens =*/(uint32_t)curr_input_ids.size(), /*.n_prompt_tokens =*/0, @@ -229,10 +255,22 @@ std::vector> Model::generate(const std::vectorlogits; +} +std::vector> Model::generate(const std::vector>& input_ids) { + if (ctx->beam_search) return beam_generate(input_ids); + if (input_ids.size() > 1) { + fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); + return {{}}; + } + + const auto& logits = evaluate_(input_ids); + if (logits.empty()) return {{}}; + + model_token next_token_id = post_process(logits.data()); + curr_input_ids = {next_token_id}; generate_count++; return 
{{next_token_id}}; } @@ -339,7 +377,7 @@ std::vector> Model::generate_tokens(const std::vector> Model::post_beam_search(model_context* lct } } -model_token Model::post_sample_top_k_top_p_repeat(float* logits) { +model_token Model::post_sample_top_k_top_p_repeat(const float* logits) { int alpha_frequency = 0; int alpha_presence = 0; int repeat_last_n = 64; @@ -392,7 +430,7 @@ model_token Model::post_sample_top_k_top_p_repeat(float* logits) { return id; } -model_token Model::post_process(float* logits) { +model_token Model::post_process(const float* logits) { assert(("Beam search does not support streaming.", params.beam_size == 1)); if (params.do_sample == false) { return post_greedy_search(logits); @@ -403,7 +441,7 @@ model_token Model::post_process(float* logits) { int Model::quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype, const std::string& alg, int group_size, const std::string& scale_dtype, - const std::string& compute_dtype, bool use_ggml) { + const std::string& compute_dtype, bool use_ggml, int threads) { quant_params q_params; #ifdef MODEL_NAME q_params.model_name = MODEL_NAME; @@ -422,10 +460,10 @@ int Model::quant_model(const std::string& model_path, const std::string& out_pat q_params.scale_dtype = scale_dtype; q_params.compute_dtype = compute_dtype; q_params.use_ggml = use_ggml; + q_params.nthread = threads; ne_ftype ftype = quant_params_to_ftype(q_params); printf("ne_ftype: %d\n", ftype); - const int nthread = q_params.nthread; auto quant_layer = get_model_quant_layer(q_params.model_name); if (model_quantize(q_params, quant_layer)) { @@ -503,12 +541,15 @@ PYBIND11_MODULE(mistral_cpp, m) py::arg("do_sample") = false, py::arg("top_k") = 40, py::arg("top_p") = 0.95, py::arg("temperature") = 0.8, py::arg("min_new_tokens") = 0, py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false, py::arg("n_keep") = 0, py::arg("n_discard") = -1, py::arg("shift_roped_k") = false, - py::arg("batch_size") = 1, py::arg("pad_token") = -1) + py::arg("batch_size") = 1, py::arg("pad_token") = -1, py::arg("memory_dtype") = "auto") .def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids")) + .def("evaluate", &Model::evaluate, "Evaluate token with input ids and output logits", + py::arg("input_ids") = std::vector{}) .def("generate_tokens", &Model::generate_tokens, "Generate tokens with input ids", py::arg("input_ids")) .def_static("quant_model", &Model::quant_model, "Quantize model", py::arg("model_path"), py::arg("out_path"), py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32, - py::arg("scale_dtype") = "fp32", py::arg("compute_dtype") = "int8", py::arg("use_ggml") = false) + py::arg("scale_dtype") = "fp32", py::arg("compute_dtype") = "int8", py::arg("use_ggml") = false, + py::arg("threads") = 8) .def("is_token_end", &Model::is_token_end) .def("reset_token_end", &Model::reset_token_end) .def("reinit", &Model::reinit); diff --git a/intel_extension_for_transformers/llm/runtime/graph/requirements.txt b/intel_extension_for_transformers/llm/runtime/graph/requirements.txt index 5d6c0c222df..f61f49ecd8a 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/requirements.txt +++ b/intel_extension_for_transformers/llm/runtime/graph/requirements.txt @@ -5,4 +5,5 @@ sentencepiece protobuf<3.20 einops accelerate -peft \ No newline at end of file +peft +datasets diff --git a/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py 
b/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py new file mode 100644 index 00000000000..d3ef5519134 --- /dev/null +++ b/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +import os +import pathlib +from typing import Dict, List + +import matplotlib.pyplot as plt +import torch +from tqdm import tqdm + +logging.basicConfig() +logger = logging.getLogger('perplexity') + +''' +Preparing test dataset: + +>>> import datasets +>>> dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split='test', num_proc=16) +>>> dataset.save_to_disk('~/wikitext-2-raw-v1-data-test') +>>> dataset = datasets.load_dataset("pg19", split='test', num_proc=16) +>>> dataset.save_to_disk('~/pg19-data-test') +''' + + +def try_resolve_dir(d): + resolved = pathlib.Path(d).expanduser().resolve() + if resolved.exists(): + return str(resolved) + return d + + +def get_ppl(sum_nll, sum_nll2, cnt: int): + ''' Get ppl and its standard deviation from sum of negative log likelihood ''' + nll = sum_nll / cnt + nll2 = sum_nll2 / cnt + ppl = math.exp(nll) + return ppl, 0. 
if cnt <= 1 else math.sqrt((nll2 - nll * nll) / (cnt-1)) + + +def perplexity(model_name, dataset_name, **kwargs): + import datasets + from intel_extension_for_transformers.transformers import ( + AutoModelForCausalLM, WeightOnlyQuantConfig) + from transformers import AutoTokenizer, AutoConfig + model_name = try_resolve_dir(model_name) + dataset_name = try_resolve_dir(dataset_name) + + ctx_size = kwargs.get("ctx_size", 256) + prompt_size = kwargs.get("prompt_size", ctx_size // 4) # use one quarter as prompt + n_threads = kwargs.get("n_threads", len(os.sched_getaffinity(0))) # Note: linux only + n_pred_per_sample = kwargs.get("n_pred_per_sample", ctx_size * 2) + n_sampels = kwargs.get("n_sampels", 2) + data_text_concat = kwargs.get("data_text_concat", "wikitext-2-raw-v1" in dataset_name) # concat samples with `\n\n` + default_model_kwargs = {"batch_size": 256, "ctx_size": 256, "n_keep": 4} + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + data = datasets.load_from_disk(dataset_name) + test_text = data['text'] + if data_text_concat: + test_text = ['\n\n'.join(test_text)] + + if n_sampels < 0: + n_sampels = len(test_text) + elif n_sampels > len(test_text): + logger.warning(f"Try to eval {n_sampels} samples but there are only {len(test_text)} in the dataset!") + n_sampels = len(test_text) + + test_ids = [] + with tqdm(total=n_sampels, desc="tokenizing") as pbar: + length_needed = prompt_size + n_pred_per_sample + for text in test_text: + if len(test_ids) > n_sampels: + break + ids = tokenizer(text, return_tensors="pt", max_length=length_needed, truncation=True).input_ids + if ids.shape.numel() >= length_needed: + test_ids.append(ids) + pbar.update(1) + + del tokenizer + + vocab_size: int = AutoConfig.from_pretrained(model_name, trust_remote_code=True).vocab_size + + woq_kwargs = {k: kwargs[k] for k in kwargs if k in + ['use_cache', 'compute_dtype', 'weight_dtype', 'scale_dtype', 'group_size', 'use_ggml']} + model_kwargs = {k: kwargs[k] for k in kwargs if k in + ['n_keep', 'shift_roped_k', 'memory_dtype']} + model_kwargs = {**default_model_kwargs, **model_kwargs} + woq_config = WeightOnlyQuantConfig(**woq_kwargs) + model = AutoModelForCausalLM.from_pretrained( + model_name, quantization_config=woq_config, trust_remote_code=True, **model_kwargs) + + ppl_hist = [{} for _ in range(n_sampels)] # ppl_hist[i_sample][end_pos] = ppl + sum_nll = [0. for _ in range(n_sampels)] # sum of negative log likelyhood + sum_nll2 = [0. 
for _ in range(n_sampels)] # sum of nll square + + pbar = tqdm(range(n_pred_per_sample * n_sampels)) + for i in pbar: + i_sample = i // n_pred_per_sample + i_pred = i % n_pred_per_sample + + is_first = (i_pred == 0) + + begin_pos = 0 if is_first else i_pred + prompt_size - 1 + end_pos = i_pred + prompt_size + cur_input = test_ids[i_sample][:, begin_pos:end_pos] + cur_target: torch.Tensor = test_ids[i_sample][:, end_pos] + out = model(cur_input, threads=n_threads, reinit=is_first) + out = torch.tensor(out).reshape(-1, vocab_size) + logsoftmax = out.log_softmax(-1) + nll = logsoftmax.take_along_dim(cur_target.view(-1, 1), 1) + assert len(nll) == 1 + nll_v = -nll.flatten().tolist()[0] + sum_nll[i_sample] += nll_v + sum_nll2[i_sample] += nll_v * nll_v + + cur_ppl, cur_sd = get_ppl(sum_nll[i_sample], sum_nll2[i_sample], i_pred + 1) + msg = f"Sample {i_sample + 1} / {n_sampels}; PPL = {cur_ppl:.4f} +/- {cur_ppl * cur_sd:.5f}" + pbar.set_description(msg, False) + ppl_hist[i_sample][end_pos] = cur_ppl + + return ppl_hist + + +def add_log_ppl_line(ax: plt.Axes, ppl_data: List[Dict[int, float]], label="log PPL"): + """ Plot PPL and return xmax / ymax""" + xs = [] + ys = [] + max_pos = max(max(d.keys()) for d in ppl_data) + for i in range(max_pos + 1): + ppls = [d[i] for d in ppl_data if i in d] + if not ppls: + continue + xs.append(i) + ys.append(math.log(sum(ppls) / len(ppls))) # average over samples + ax.plot(xs, ys, label=label) + + xmax = xs[torch.argmax(torch.tensor(ys)).item()] + ymax = max(ys) + return xmax, ymax, xs, ys + + +def draw_ppl(img_path: str, ppl_data: List[Dict[int, float]], ctx_size: int, model_title: str): + fig, ax = plt.subplots() + xmax, ymax, _, _ = add_log_ppl_line(ax, ppl_data) + ax.annotate(f"max={ymax:.4f}", (xmax, ymax)) + + ctx_line = ax.axvline(ctx_size, linestyle='--', color='r') + ctx_line.set_label('KV Cache Size') + ax.set_xlabel('Context Length') + ax.set_ylabel('Log Perplexity') + ax.legend() + + ax.set_title(model_title) + fig.suptitle("Language modeling perplexity") + fig.savefig(img_path) + + print(f"Max PPL: {math.exp(ymax)}") + return fig + + +def add_quant_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group('quantize config') + group.add_argument('--use_cache', + action="store_true", + help="Use local quantized model if file exists") + group.add_argument( + "--weight_dtype", + choices=["int4", "int8"], + help="Data type of quantized weight: int4/int8 (default: int4)", + default="int4", + ) + group.add_argument( + "--alg", + type=str, + help="Quantization algorithm to use: sym/asym (default: sym)", + default="sym", + ) + group.add_argument( + "--group_size", type=int, help="Group size: Int (default: 32)", default=32 + ) + group.add_argument( + "--scale_dtype", + type=str, + help="Data type of scales: bf16/fp32 (default: fp32)", + default="fp32", + ) + group.add_argument( + "--compute_dtype", + type=str, + help="Data type of Gemm computation: int8/bf16/fp32 (default: int8)", + default="int8", + ) + group.add_argument( + "--use_ggml", + action="store_true", + help="enable ggml for quantization and inference", + ) + return group + + +def add_run_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group('model run config') + group.add_argument( + "--n_keep", + type=int, + help="Number of tokens to keep from the initial prompt: Int (default: 0; -1 = all)", + default=1, + ) + group.add_argument( + "--shift_roped_k", + action="store_true", + help="Use ring-buffer and thus do not re-computing after reaching ctx_size (default: 
False)", + ) + group.add_argument("--memory_dtype", + type=str, + help="Data type of the kv memory", + choices=['f32', 'f16', 'auto'], + default="auto") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate perplexity for a model givan a dataset") + parser.add_argument('--model_name', type=str, default="~/Llama-2-7b-chat-hf") + parser.add_argument('--dataset_name', type=str, default="~/pg19-data-test") + parser.add_argument('--ctx_size', type=int, default=256) + parser.add_argument('--prompt_size', type=int) + parser.add_argument('--n_threads', type=int) + parser.add_argument('--n_pred_per_sample', type=int) + parser.add_argument('--n_sampels', type=int) + parser.add_argument('--data_text_concat', action="store_true") + parser.add_argument('--fig_path', type=str, default="out/ppl.png") + add_quant_args(parser) + add_run_args(parser) + + ns_args = parser.parse_args() + args = vars(ns_args) + args = {k: args[k] for k in args if args[k] is not None} + + pathlib.Path.mkdir(pathlib.Path("out"), exist_ok=True) + ppl_data = perplexity(**args) + + # draw the graph + job_name = f"{ns_args.model_name}-{ns_args.weight_dtype}" + if ns_args.weight_dtype != 'fp32': + job_name += f"-{ns_args.compute_dtype}-g{ns_args.group_size}" + + job_name += f"-keep{ns_args.n_keep}" + draw_ppl(ns_args.fig_path, ppl_data, ns_args.ctx_size, job_name) + + # dump raw data + import json + with open('out/ppl_data.json', 'w') as f: + json.dump(ppl_data, f, indent=2) From dbec7a48053c52e2e456842c62544af9b10a3714 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Wed, 15 Nov 2023 19:13:19 +0800 Subject: [PATCH 02/11] fix --- .../llm/runtime/graph/__init__.py | 71 ++++++++++++------- .../runtime/graph/application/main_pybind.cpp | 36 +++++----- .../llm/runtime/graph/scripts/perplexity.py | 38 ++++++---- tests/test_llm_runtime.py | 3 + 4 files changed, 90 insertions(+), 58 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/__init__.py b/intel_extension_for_transformers/llm/runtime/graph/__init__.py index 9433ed7054c..fb3fdb63037 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/__init__.py +++ b/intel_extension_for_transformers/llm/runtime/graph/__init__.py @@ -31,54 +31,68 @@ def __init__(self): self.bin_file = None self.generate_round = 0 - def __import_package(self, model_name): + def __import_package(self, model_type): if self.module: return - if model_name == "gptj": + if model_type == "gptj": import intel_extension_for_transformers.llm.runtime.graph.gptj_cpp as cpp_model - elif model_name == "falcon": + elif model_type == "falcon": import intel_extension_for_transformers.llm.runtime.graph.falcon_cpp as cpp_model - elif model_name == "gptneox": + elif model_type == "gptneox": import intel_extension_for_transformers.llm.runtime.graph.gptneox_cpp as cpp_model - elif model_name == "dolly": + elif model_type == "dolly": import intel_extension_for_transformers.llm.runtime.graph.dolly_cpp as cpp_model - elif model_name == "llama" or model_name == "llama2": + elif model_type == "llama" or model_type == "llama2": import intel_extension_for_transformers.llm.runtime.graph.llama_cpp as cpp_model - elif model_name == "mpt": + elif model_type == "mpt": import intel_extension_for_transformers.llm.runtime.graph.mpt_cpp as cpp_model - elif model_name == "gpt_bigcode" or model_name == "starcoder": + elif model_type == "gpt_bigcode" or model_type == "starcoder": import intel_extension_for_transformers.llm.runtime.graph.starcoder_cpp as cpp_model - elif model_name == 
"opt": + elif model_type == "opt": import intel_extension_for_transformers.llm.runtime.graph.opt_cpp as cpp_model - elif model_name == "bloom": + elif model_type == "bloom": import intel_extension_for_transformers.llm.runtime.graph.bloom_cpp as cpp_model - elif model_name == "chatglm": + elif model_type == "chatglm": import intel_extension_for_transformers.llm.runtime.graph.chatglm_cpp as cpp_model - elif model_name == "chatglm2": + elif model_type == "chatglm2": import intel_extension_for_transformers.llm.runtime.graph.chatglm2_cpp as cpp_model - elif model_name == "baichuan": + elif model_type == "baichuan": import intel_extension_for_transformers.llm.runtime.graph.baichuan_cpp as cpp_model - elif model_name == "polyglot": + elif model_type == "polyglot": import intel_extension_for_transformers.llm.runtime.graph.polyglot_cpp as cpp_model - elif model_name == "mistral": + elif model_type == "mistral": import intel_extension_for_transformers.llm.runtime.graph.mistral_cpp as cpp_model else: - raise TypeError("Unspported model type {}!".format(model_name)) + raise TypeError("Unspported model type {}!".format(model_type)) self.module = cpp_model + @staticmethod + def get_model_type(model_config): + model_type = model_maps.get(model_config.model_type, model_config.model_type) + if model_type == "chatglm" and "chatglm2" in model_config._name_or_path: + model_type = "chatglm2" + return model_type + def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs): self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - model_type = model_maps.get(self.config.model_type, self.config.model_type) - if model_type == "chatglm" and "chatglm2" in self.config._name_or_path: - model_type = "chatglm2" + model_type = Model.get_model_type(self.config) self.__import_package(model_type) # check cache and quantization output_path = "runtime_outs" os.makedirs(output_path, exist_ok=True) fp32_bin = "{}/ne_{}_f32.bin".format(output_path, model_type) - quant_bin = "{}/ne_{}_q.bin".format(output_path, model_type) + quant_desc = quant_kwargs['weight_dtype'] + if quant_kwargs['use_ggml']: + quant_desc += "_ggml" + else: + quant_desc += "_jblas_c" + quant_kwargs['compute_dtype'] + if quant_kwargs['group_size'] == -1: + quant_desc += "_pc" + else: + quant_desc += "_g{}".format(quant_kwargs['group_size']) + quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc) if not_quant: self.bin_file = fp32_bin @@ -87,19 +101,22 @@ def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs): if use_cache and os.path.exists(self.bin_file): return - convert_model(model_name, fp32_bin, "f32") - assert os.path.exists(fp32_bin), "Fail to convert pytorch model" + if not os.path.exists(fp32_bin): + convert_model(model_name, fp32_bin, "f32") + assert os.path.exists(fp32_bin), "Fail to convert pytorch model" if not_quant: print("FP32 model will be used.") return self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, **quant_kwargs) assert os.path.exists(quant_bin), "Fail to quantize model" + # clean - os.remove(fp32_bin) + if not use_cache: + os.remove(fp32_bin) - def init_from_bin(self, model_name, model_path, **generate_kwargs): - self.__import_package(model_name) + def init_from_bin(self, model_type, model_path, **generate_kwargs): + self.__import_package(model_type) self.model = self.module.Model() if "threads" not in generate_kwargs: threads = 
os.getenv("OMP_NUM_THREADS") @@ -109,8 +126,8 @@ def init_from_bin(self, model_name, model_path, **generate_kwargs): generate_kwargs["threads"] = int(threads) self.model.init_model(model_path, **generate_kwargs) - def quant_model(self, model_name, model_path, out_path, **quant_kwargs): - self.__import_package(model_name) + def quant_model(self, model_type, model_path, out_path, **quant_kwargs): + self.__import_package(model_type) self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs) def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs): diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index 294841596cc..2def96ddbeb 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -16,23 +16,26 @@ #define _GNU_SOURCE #endif +#include +#include +#include #include + +#include #include #include #include #include -#include -#include -#include #include +#include #include #include #include -#include -#include +#include + #include "common.h" -#include "models/model_utils/model_types.h" #include "models/model_utils/model_config.h" +#include "models/model_utils/model_types.h" #include "models/model_utils/model_utils.h" #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) @@ -41,10 +44,12 @@ #elif defined(_WIN32) #define WIN32_LEAN_AND_MEAN #define NOMINMAX -#include #include +#include #endif +namespace py = pybind11; + std::shared_ptr get_model_quant_layer(const std::string model_name) { return ql_registry::create_ql(model_name); } @@ -63,13 +68,14 @@ class Model { std::vector> generate(const std::vector>& input_ids); std::vector> generate_tokens(const std::vector>& input_ids); const std::vector& evaluate_(const std::vector>& input_ids); - std::vector> evaluate(const std::vector>& input_ids) { + py::array_t evaluate(const std::vector>& input_ids) { if (input_ids.size() != 1) { fprintf(stderr, "\nERROR: only support batch == 1 input!\n"); - return {{}}; + return py::array_t(); } const auto& logits = evaluate_(input_ids); - return {logits}; + return py::array_t(logits.size(), logits.data()) + .reshape({py::ssize_t(-1), static_cast(ctx->model.hparams.n_vocab)}); } bool is_token_end() { return token_eos; } static int quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype, @@ -170,7 +176,7 @@ std::vector> Model::beam_generate(const std::vectorbatch_size); if (ctx->batch_size > 1 && ctx->vocab.pad_token_id == -1) { fprintf(stderr, "\nERROR: please set pad_token for beam search multi-batch generation!\n"); - return {{}}; + return {}; } std::vector inputs; for (int bs = 0; bs < input_ids.size(); ++bs) { @@ -263,11 +269,11 @@ std::vector> Model::generate(const std::vectorbeam_search) return beam_generate(input_ids); if (input_ids.size() > 1) { fprintf(stderr, "\nERROR: Only beam search supports multi-batch generation!\n"); - return {{}}; + return {}; } const auto& logits = evaluate_(input_ids); - if (logits.empty()) return {{}}; + if (logits.empty()) return {}; model_token next_token_id = post_process(logits.data()); curr_input_ids = {next_token_id}; @@ -473,8 +479,6 @@ int Model::quant_model(const std::string& model_path, const std::string& out_pat return 0; } -namespace py = pybind11; - #if MODEL_NAME_ID 
== 1 PYBIND11_MODULE(gptj_cpp, m) @@ -544,7 +548,7 @@ PYBIND11_MODULE(mistral_cpp, m) py::arg("batch_size") = 1, py::arg("pad_token") = -1, py::arg("memory_dtype") = "auto") .def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids")) .def("evaluate", &Model::evaluate, "Evaluate token with input ids and output logits", - py::arg("input_ids") = std::vector{}) + py::arg("input_ids") = std::vector>{}) .def("generate_tokens", &Model::generate_tokens, "Generate tokens with input ids", py::arg("input_ids")) .def_static("quant_model", &Model::quant_model, "Quantize model", py::arg("model_path"), py::arg("out_path"), py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32, diff --git a/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py b/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py index d3ef5519134..724de58fd11 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py +++ b/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py @@ -69,7 +69,7 @@ def perplexity(model_name, dataset_name, **kwargs): n_pred_per_sample = kwargs.get("n_pred_per_sample", ctx_size * 2) n_sampels = kwargs.get("n_sampels", 2) data_text_concat = kwargs.get("data_text_concat", "wikitext-2-raw-v1" in dataset_name) # concat samples with `\n\n` - default_model_kwargs = {"batch_size": 256, "ctx_size": 256, "n_keep": 4} + default_model_kwargs = {"n_batch": 256, "ctx_size": 256, "n_keep": 4} tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) data = datasets.load_from_disk(dataset_name) @@ -94,18 +94,23 @@ def perplexity(model_name, dataset_name, **kwargs): test_ids.append(ids) pbar.update(1) - del tokenizer - - vocab_size: int = AutoConfig.from_pretrained(model_name, trust_remote_code=True).vocab_size - - woq_kwargs = {k: kwargs[k] for k in kwargs if k in - ['use_cache', 'compute_dtype', 'weight_dtype', 'scale_dtype', 'group_size', 'use_ggml']} - model_kwargs = {k: kwargs[k] for k in kwargs if k in - ['n_keep', 'shift_roped_k', 'memory_dtype']} + quantized_weight_path = kwargs.pop('quantized_weight_path', None) + if quantized_weight_path: + from intel_extension_for_transformers.llm.runtime.graph import Model + model = Model() + assert pathlib.Path(quantized_weight_path).is_file(), "Quantized weight not exist!" + model.bin_file = quantized_weight_path + model.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + model.model_type = Model.get_model_type(model.config) + model.tokenizer = tokenizer + else: + woq_kwargs = {k: kwargs[k] for k in kwargs if k in + ['use_cache', 'compute_dtype', 'weight_dtype', 'scale_dtype', 'group_size', 'use_ggml']} + woq_config = WeightOnlyQuantConfig(**woq_kwargs) + model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True) + + model_kwargs = {k: kwargs[k] for k in kwargs if k in ['n_keep', 'shift_roped_k', 'memory_dtype']} model_kwargs = {**default_model_kwargs, **model_kwargs} - woq_config = WeightOnlyQuantConfig(**woq_kwargs) - model = AutoModelForCausalLM.from_pretrained( - model_name, quantization_config=woq_config, trust_remote_code=True, **model_kwargs) ppl_hist = [{} for _ in range(n_sampels)] # ppl_hist[i_sample][end_pos] = ppl sum_nll = [0. 
for _ in range(n_sampels)] # sum of negative log likelyhood @@ -122,9 +127,8 @@ def perplexity(model_name, dataset_name, **kwargs): end_pos = i_pred + prompt_size cur_input = test_ids[i_sample][:, begin_pos:end_pos] cur_target: torch.Tensor = test_ids[i_sample][:, end_pos] - out = model(cur_input, threads=n_threads, reinit=is_first) - out = torch.tensor(out).reshape(-1, vocab_size) - logsoftmax = out.log_softmax(-1) + out = model(cur_input, threads=n_threads, reinit=is_first, **model_kwargs) + logsoftmax = torch.from_numpy(out).log_softmax(-1) nll = logsoftmax.take_along_dim(cur_target.view(-1, 1), 1) assert len(nll) == 1 nll_v = -nll.flatten().tolist()[0] @@ -178,6 +182,10 @@ def draw_ppl(img_path: str, ppl_data: List[Dict[int, float]], ctx_size: int, mod def add_quant_args(parser: argparse.ArgumentParser): group = parser.add_argument_group('quantize config') + group.add_argument('--quantized_weight_path', + type=str, + help="path to quantized weight; other quant args will be ignored if specified", + default="") group.add_argument('--use_cache', action="store_true", help="Use local quantized model if file exists") diff --git a/tests/test_llm_runtime.py b/tests/test_llm_runtime.py index 75f904fff36..45a6699b11f 100644 --- a/tests/test_llm_runtime.py +++ b/tests/test_llm_runtime.py @@ -67,3 +67,6 @@ def test_beam_search(self): pad_token=pad_token) for i in range(len(itrex_generate_ids)): self.assertListEqual(pt_generate_ids[i], itrex_generate_ids[i]) + +if __name__ == "__main__": + unittest.main() From 73221201564b068f1eab7521af8abf2c567fb08c Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Wed, 15 Nov 2023 22:26:15 +0800 Subject: [PATCH 03/11] fit --- .../llm/runtime/graph/scripts/perplexity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py b/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py index 724de58fd11..5336352cbb1 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py +++ b/intel_extension_for_transformers/llm/runtime/graph/scripts/perplexity.py @@ -69,7 +69,7 @@ def perplexity(model_name, dataset_name, **kwargs): n_pred_per_sample = kwargs.get("n_pred_per_sample", ctx_size * 2) n_sampels = kwargs.get("n_sampels", 2) data_text_concat = kwargs.get("data_text_concat", "wikitext-2-raw-v1" in dataset_name) # concat samples with `\n\n` - default_model_kwargs = {"n_batch": 256, "ctx_size": 256, "n_keep": 4} + default_model_kwargs = {"n_batch": ctx_size, "ctx_size": ctx_size, "n_keep": 4} tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) data = datasets.load_from_disk(dataset_name) From 9c619b590a5e084cf5b4db86598c591ba7f4111e Mon Sep 17 00:00:00 2001 From: zhenwei-intel Date: Thu, 16 Nov 2023 09:27:44 +0800 Subject: [PATCH 04/11] update test Signed-off-by: zhenwei-intel --- .../llm/runtime/graph/__init__.py | 4 +- tests/test_llm_runtime.py | 40 ++++++++++++------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/__init__.py b/intel_extension_for_transformers/llm/runtime/graph/__init__.py index fb3fdb63037..0835b5d9ed6 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/__init__.py +++ b/intel_extension_for_transformers/llm/runtime/graph/__init__.py @@ -101,7 +101,7 @@ def init(self, model_name, not_quant=False, use_cache=False, **quant_kwargs): if use_cache and os.path.exists(self.bin_file): return - if not os.path.exists(fp32_bin): + if 
not use_cache or not os.path.exists(fp32_bin): convert_model(model_name, fp32_bin, "f32") assert os.path.exists(fp32_bin), "Fail to convert pytorch model" @@ -171,6 +171,8 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa streamer.put(torch.tensor([response[0]])) for i in range(len(response)): ret[i].extend(response[i]) + if beam_search: + break if stopping_criteria is not None: if stopping_criteria(torch.tensor(ret), None): break diff --git a/tests/test_llm_runtime.py b/tests/test_llm_runtime.py index 45a6699b11f..fe0ef40c1ab 100644 --- a/tests/test_llm_runtime.py +++ b/tests/test_llm_runtime.py @@ -1,4 +1,4 @@ -import numpy +import numpy as np import shutil import torch import unittest @@ -8,6 +8,13 @@ from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model from intel_extension_for_transformers.llm.runtime.graph import Model +def cmpData(numa, numb): + totalErr = ((np.abs(numa - numb))**2).sum() + totalNum = (np.abs(numa)**2).sum() + diff2 = np.sqrt(totalErr/totalNum) + + cos = np.dot(numa, numb)/(np.linalg.norm(numa)*np.linalg.norm(numb)) + return {"diff2": diff2, "cos": cos} class TestLLMRUNTIME(unittest.TestCase): @@ -17,24 +24,29 @@ def setUpClass(cls): @classmethod def tearDownClass(cls) -> None: - shutil.rmtree("./ne_chatglm_q.bin", ignore_errors=True) - shutil.rmtree("./gptj_fp32.bin", ignore_errors=True) + shutil.rmtree("./runtime_outs", ignore_errors=True) def test_llm_runtime(self): - - model_name = "/tf_dataset2/models/pytorch/chatglm2-6b" # or local path to model - woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4") - prompt = "小明的妈妈有三个孩子,老大叫大毛,老二叫二毛,老三叫什么?" + model_name = "/tf_dataset2/models/pytorch/Llama-2-7b-chat-hf" # or local path to model + woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True, not_quant=True) + prompt = "What is the meaning of life?" 
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - streamer = TextStreamer(tokenizer) + inputs = tokenizer(prompt, return_tensors="pt") - model = AutoModel.from_pretrained(model_name, quantization_config=woq_config, use_llm_runtime=True, trust_remote_code=True) - gen_tokens = model.generate(input_ids, streamer=streamer, max_new_tokens=300, seed=1) - outputs = tokenizer.batch_decode(gen_tokens) - print(outputs) - self.assertTrue("小明" in outputs[0]) + # pytorch fp32 + pt_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + pt_model.eval() + logits = pt_model(input_ids=inputs.input_ids).logits[:,-1] + pt_generate_ids = pt_model.generate(input_ids=inputs.input_ids, max_new_tokens=128).tolist() + + itrex_model = AutoModel.from_pretrained(model_name, quantization_config=woq_config, use_llm_runtime=True, trust_remote_code=True) + outputs = itrex_model(inputs.input_ids) + itrex_generate_ids = itrex_model.generate(inputs.input_ids, max_new_tokens=128) + print(cmpData(logits.detach().numpy().flatten(), outputs.flatten())) + + for i in range(len(itrex_generate_ids)): + self.assertListEqual(pt_generate_ids[i], itrex_generate_ids[i]) def test_beam_search(self): model_name = "/tf_dataset2/models/pytorch/gpt-j-6B" # or local path to model From cd8f7d2277570eee62e072cee8c9d6381cea1804 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Thu, 16 Nov 2023 10:33:38 +0800 Subject: [PATCH 05/11] fix empty batch input --- .../llm/runtime/graph/application/main_pybind.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp index 2def96ddbeb..56ea0e2874e 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp @@ -208,8 +208,9 @@ const std::vector& Model::evaluate_(const std::vector empty_id{}; + const auto& input_id0 = input_ids.empty() ? empty_id : input_ids[0]; // currently only support single batch + if (input_id0.empty()) { // use internel input id if (curr_input_ids.empty()) { fprintf(stderr, "%s: error: no input\n", __func__); return empty_ret; From dcd7ca6e8e6a0a4bdfbbbf27f45790400bdea9c9 Mon Sep 17 00:00:00 2001 From: zhenwei-intel Date: Thu, 16 Nov 2023 11:27:15 +0800 Subject: [PATCH 06/11] update runtime test Signed-off-by: zhenwei-intel --- tests/test_llm_runtime.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_llm_runtime.py b/tests/test_llm_runtime.py index fe0ef40c1ab..378ee93acaa 100644 --- a/tests/test_llm_runtime.py +++ b/tests/test_llm_runtime.py @@ -27,7 +27,7 @@ def tearDownClass(cls) -> None: shutil.rmtree("./runtime_outs", ignore_errors=True) def test_llm_runtime(self): - model_name = "/tf_dataset2/models/pytorch/Llama-2-7b-chat-hf" # or local path to model + model_name = "/tf_dataset2/models/nlp_toolkit/llama-2-7b-chat/Llama-2-7b-chat-hf" woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True, not_quant=True) prompt = "What is the meaning of life?" 
@@ -38,15 +38,17 @@ def test_llm_runtime(self): pt_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) pt_model.eval() logits = pt_model(input_ids=inputs.input_ids).logits[:,-1] - pt_generate_ids = pt_model.generate(input_ids=inputs.input_ids, max_new_tokens=128).tolist() - + pt_generate_ids = pt_model.generate(input_ids=inputs.input_ids, do_sample=False, max_new_tokens=100)[0].tolist() + print(tokenizer.decode(pt_generate_ids)) + itrex_model = AutoModel.from_pretrained(model_name, quantization_config=woq_config, use_llm_runtime=True, trust_remote_code=True) - outputs = itrex_model(inputs.input_ids) - itrex_generate_ids = itrex_model.generate(inputs.input_ids, max_new_tokens=128) + outputs = itrex_model.forward(inputs.input_ids) + itrex_generate_ids = itrex_model.generate(inputs.input_ids, do_sample=False, max_new_tokens=100)[0] + print(tokenizer.decode(itrex_generate_ids)) print(cmpData(logits.detach().numpy().flatten(), outputs.flatten())) - for i in range(len(itrex_generate_ids)): - self.assertListEqual(pt_generate_ids[i], itrex_generate_ids[i]) + for i in range(len(pt_generate_ids)): + self.assertEqual(pt_generate_ids[i], itrex_generate_ids[i]) def test_beam_search(self): model_name = "/tf_dataset2/models/pytorch/gpt-j-6B" # or local path to model From a93aa4511243e34eb38b9a162846e72aba982691 Mon Sep 17 00:00:00 2001 From: zhenwei-intel Date: Thu, 16 Nov 2023 11:50:39 +0800 Subject: [PATCH 07/11] add script for diff test Signed-off-by: zhenwei-intel --- .../llm/runtime/graph/scripts/cal_diff.py | 42 +++++++++++++++++++ tests/test_llm_runtime.py | 6 +-- 2 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 intel_extension_for_transformers/llm/runtime/graph/scripts/cal_diff.py diff --git a/intel_extension_for_transformers/llm/runtime/graph/scripts/cal_diff.py b/intel_extension_for_transformers/llm/runtime/graph/scripts/cal_diff.py new file mode 100644 index 00000000000..39d67521d1d --- /dev/null +++ b/intel_extension_for_transformers/llm/runtime/graph/scripts/cal_diff.py @@ -0,0 +1,42 @@ +import numpy as np +import argparse +from transformers import AutoTokenizer, TextStreamer +from intel_extension_for_transformers.transformers import AutoModel, WeightOnlyQuantConfig, AutoModelForCausalLM + + +def cmpData(numa, numb): + totalErr = ((np.abs(numa - numb))**2).sum() + totalNum = (np.abs(numa)**2).sum() + diff2 = np.sqrt(totalErr/totalNum) + + cos = np.dot(numa, numb)/(np.linalg.norm(numa)*np.linalg.norm(numb)) + return {"diff2": diff2, "cos": cos} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate diff for a model") + parser.add_argument('--model_name', type=str, default="~/Llama-2-7b-chat-hf") + args = parser.parse_args() + + woq_configs = { + "fp32": WeightOnlyQuantConfig(use_cache=True, not_quant=True), + "ggml_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True, use_ggml=True), + "jblas_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True), + "jblas_int8": WeightOnlyQuantConfig(compute_dtype="bf16", weight_dtype="int8", use_cache=True), + } + prompt = "What is the meaning of life?" 
+ + model_name = args.model_name + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + inputs = tokenizer(prompt, return_tensors="pt") + + pt_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + pt_model.eval() + pt_logits = pt_model(input_ids=inputs.input_ids).logits[:,-1] + + for config_type in woq_configs: + itrex_model = AutoModel.from_pretrained(model_name, quantization_config=woq_configs[config_type], + use_llm_runtime=True, trust_remote_code=True) + itrex_logits = itrex_model(inputs.input_ids) + + print(config_type, cmpData(pt_logits.detach().numpy().flatten(), itrex_logits.flatten())) diff --git a/tests/test_llm_runtime.py b/tests/test_llm_runtime.py index 378ee93acaa..5c049f5004e 100644 --- a/tests/test_llm_runtime.py +++ b/tests/test_llm_runtime.py @@ -37,15 +37,15 @@ def test_llm_runtime(self): # pytorch fp32 pt_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) pt_model.eval() - logits = pt_model(input_ids=inputs.input_ids).logits[:,-1] + pt_logits = pt_model(input_ids=inputs.input_ids).logits[:,-1] pt_generate_ids = pt_model.generate(input_ids=inputs.input_ids, do_sample=False, max_new_tokens=100)[0].tolist() print(tokenizer.decode(pt_generate_ids)) itrex_model = AutoModel.from_pretrained(model_name, quantization_config=woq_config, use_llm_runtime=True, trust_remote_code=True) - outputs = itrex_model.forward(inputs.input_ids) + itrex_outputs = itrex_model(inputs.input_ids) itrex_generate_ids = itrex_model.generate(inputs.input_ids, do_sample=False, max_new_tokens=100)[0] print(tokenizer.decode(itrex_generate_ids)) - print(cmpData(logits.detach().numpy().flatten(), outputs.flatten())) + print(cmpData(pt_logits.detach().numpy().flatten(), itrex_outputs.flatten())) for i in range(len(pt_generate_ids)): self.assertEqual(pt_generate_ids[i], itrex_generate_ids[i]) From d3a8866c36ed8ca42330a56f3672d122da4a9c95 Mon Sep 17 00:00:00 2001 From: zhenwei-intel Date: Thu, 16 Nov 2023 13:16:06 +0800 Subject: [PATCH 08/11] add copyright Signed-off-by: zhenwei-intel --- .../llm/runtime/graph/scripts/cal_diff.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/scripts/cal_diff.py b/intel_extension_for_transformers/llm/runtime/graph/scripts/cal_diff.py index 39d67521d1d..be47c9f199e 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/scripts/cal_diff.py +++ b/intel_extension_for_transformers/llm/runtime/graph/scripts/cal_diff.py @@ -1,3 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import numpy as np import argparse from transformers import AutoTokenizer, TextStreamer @@ -5,8 +22,8 @@ def cmpData(numa, numb): - totalErr = ((np.abs(numa - numb))**2).sum() - totalNum = (np.abs(numa)**2).sum() + totalErr = ((numa - numb)**2).sum() + totalNum = (numa**2).sum() diff2 = np.sqrt(totalErr/totalNum) cos = np.dot(numa, numb)/(np.linalg.norm(numa)*np.linalg.norm(numb)) From ca2d980bf2ba1a0a7907023f557d4fe36c31d493 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Thu, 16 Nov 2023 20:51:21 +0800 Subject: [PATCH 09/11] Add Perplexity script in README --- intel_extension_for_transformers/llm/runtime/graph/README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/intel_extension_for_transformers/llm/runtime/graph/README.md b/intel_extension_for_transformers/llm/runtime/graph/README.md index 45820ff5129..a9905f2415a 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/README.md +++ b/intel_extension_for_transformers/llm/runtime/graph/README.md @@ -347,7 +347,7 @@ class StopOnTokens(StoppingCriteria): self.min_length = min_length self.start_length = start_length self.stop_token_id = stop_token_id - + def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs ) -> bool: @@ -369,3 +369,6 @@ stopping_criteria = StoppingCriteriaList( outputs = model.generate(inputs, streamer=streamer, stopping_criteria=stopping_criteria) ``` + +### 6. Perplexity (measuring model quality) +You can use the [scripts/perplexity.py](./scripts/perplexity.py) script to over a given (subset of) dataset. Run `python scripts/perplexity.py --help` for detailed usage. For more infomation of the perplexity metric, see https://huggingface.co/docs/transformers/perplexity. From 436897cc8f922d787f1b54327c3bc8b6e7e8b341 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Thu, 16 Nov 2023 20:54:44 +0800 Subject: [PATCH 10/11] fix 'batch_size=4' --- tests/test_llm_runtime.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_llm_runtime.py b/tests/test_llm_runtime.py index 5c049f5004e..f2d8858a39e 100644 --- a/tests/test_llm_runtime.py +++ b/tests/test_llm_runtime.py @@ -75,12 +75,13 @@ def test_beam_search(self): early_stopping=True, num_beams=4).tolist() # llm runtime fp32 woq_config = WeightOnlyQuantConfig(not_quant=True) - itrex_model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True) - itrex_generate_ids = itrex_model.generate(inputs.input_ids, batch_size=4, num_beams=4, - max_new_tokens=128, min_new_tokens=30, early_stopping=True, - pad_token=pad_token) + itrex_model = AutoModelForCausalLM.from_pretrained( + model_name, quantization_config=woq_config, trust_remote_code=True) + itrex_generate_ids = itrex_model.generate( + inputs.input_ids, num_beams=4, max_new_tokens=128, min_new_tokens=30, early_stopping=True, pad_token=pad_token) for i in range(len(itrex_generate_ids)): self.assertListEqual(pt_generate_ids[i], itrex_generate_ids[i]) + if __name__ == "__main__": unittest.main() From e6c924e5fab661d3cb5ba6a54b866dc966f9b9b5 Mon Sep 17 00:00:00 2001 From: VincyZhang Date: Thu, 16 Nov 2023 20:58:20 +0800 Subject: [PATCH 11/11] Update requirements.txt Signed-off-by: VincyZhang --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 04cabc872ef..602d9aeca71 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -11,7 +11,7 @@ transformers==4.33 intel-tensorflow==2.13.0 
 torchprofile
 intel-extension-for-pytorch
-tokenizers<=0.12.1
+tokenizers
 sentencepiece != 0.1.92
 accelerate
 evaluate
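
Usage note for the perplexity script added in this series: after saving a test split locally with `datasets` `save_to_disk` (see the docstring at the top of `scripts/perplexity.py`), the script can be run with its argparse defaults, for example `python scripts/perplexity.py --model_name ~/Llama-2-7b-chat-hf --dataset_name ~/pg19-data-test --ctx_size 256 --memory_dtype auto`. It writes the perplexity curve to `out/ppl.png` and the raw per-position data to `out/ppl_data.json`; the model and dataset paths shown are the script's own placeholder defaults, not required locations.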
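For reference, a minimal sketch of the logits/perplexity flow that the new `evaluate`/`__call__` path enables. The model path, the choice to score only the last four prompt tokens, and the omission of runtime kwargs such as `threads`, `ctx_size`, and `n_batch` (which `scripts/perplexity.py` does pass) are simplifying assumptions for illustration, not part of the patches themselves:

```python
# Minimal sketch, assuming an ITREX build containing the patches above and a
# Llama-2 checkpoint reachable at `model_name` (placeholder id/path).
import math

import torch
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM, WeightOnlyQuantConfig)

model_name = "meta-llama/Llama-2-7b-chat-hf"  # placeholder: HF id or local path
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=woq_config, trust_remote_code=True)

ids = tokenizer("What is the meaning of life?", return_tensors="pt").input_ids
n_scored = 4                               # assumption: score only the last 4 tokens
prompt_size = ids.shape[1] - n_scored
sum_nll = 0.0
for pos in range(prompt_size, ids.shape[1]):
    is_first = pos == prompt_size
    begin = 0 if is_first else pos - 1     # prefill once, then feed one token per step
    # __call__ evaluates the tokens and returns next-token logits as a numpy array
    logits = torch.from_numpy(model(ids[:, begin:pos], reinit=is_first))
    sum_nll -= logits.log_softmax(-1)[-1, ids[0, pos]].item()  # accumulate -log p(target)
print("PPL over scored tokens:", math.exp(sum_nll / n_scored))
```

As in `scripts/perplexity.py`, perplexity is the exponential of the mean negative log-likelihood of the teacher-forced targets; the first call prefills the prompt (`reinit=True`) and each subsequent call feeds only the previous target token so the runtime's KV cache carries the context.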