[RUNTIME] Enabling streaming llm for Runtime (#501)
* Support StreamingLLM on CPU

Signed-off-by: zhenwei-intel <zhenwei.liu@intel.com>
zhenwei-intel authored and VincyZhang committed Oct 23, 2023
1 parent e1f9e2b commit ffc73bb
Showing 9 changed files with 38 additions and 30 deletions.
File 1 of 9
@@ -18,7 +18,7 @@
 from transformers import AutoConfig
 from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model
 import torch
-model_maps = {"gpt_neox": "gptneox"}
+model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}

 class Model:
     def __init__(self):
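The only change here is the extra "gpt_bigcode": "starcoder" entry, which lets StarCoder-family checkpoints resolve to the runtime's starcoder backend. As a rough illustration of how such a map is usually consulted (the lookup itself is not shown in this hunk, so the helper below is an assumption, not code from the commit):

from transformers import AutoConfig

model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}

def runtime_model_type(model_name_or_path: str) -> str:
    # Hypothetical helper: only the dictionary contents come from this commit.
    hf_type = AutoConfig.from_pretrained(model_name_or_path).model_type
    # e.g. StarCoder checkpoints report model_type "gpt_bigcode" -> "starcoder";
    # unmapped types such as "llama" pass through unchanged.
    return model_maps.get(hf_type, hf_type)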
File 2 of 9
@@ -57,7 +57,7 @@ class Model {
 }
 void init_model(const std::string& model_path, int n_predict, int batch_size, int ctx_size, int seed, int threads,
                 float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature,
-                int min_new_tokens, float length_penalty, bool early_stopping);
+                int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard);
 void reinit();
 std::vector<model_token> generate(const std::vector<model_token>& input_ids);
 std::vector<model_token> generate_tokens(const std::vector<model_token>& input_ids);
@@ -85,7 +85,8 @@ class Model {

 void Model::init_model(const std::string& model_path, int max_new_tokens, int batch_size, int ctx_size, int seed,
                        int threads, float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p,
-                       float temperature, int min_new_tokens, float length_penalty, bool early_stopping) {
+                       float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep,
+                       int n_discard) {
 #ifdef MODEL_NAME
   params.model_name = MODEL_NAME;
 #endif
@@ -106,6 +107,8 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int batch_size, int ctx_size, int seed,
   params.top_k = top_k;
   params.top_p = top_p;
   params.temp = temperature;
+  params.n_keep = n_keep;
+  params.n_discard = n_discard;

   printf("beam_size: %d, do_sample: %d, top_k: %d, top_p: %f\n", params.beam_size, params.do_sample, params.top_k,
          params.top_p);
@@ -141,17 +144,14 @@ std::vector<model_token> Model::generate(const std::vector<model_token>& input_ids) {
     last_n_tokens.push_back(item);
   }
   // infinite text generation via context swapping
-  // if we run out of context:
-  // - take the n_keep first tokens from the original prompt (via n_past)
-  // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
   if (n_past + curr_input_ids.size() > n_ctx) {
-    const int n_left = n_past - params.n_keep;
-
-    // always keep the first token - BOS
+    // always keep the first token
     n_past = std::max(1, params.n_keep);

-    // insert n_left/2 tokens at the start of embd from last_n_tokens
-    curr_input_ids.insert(curr_input_ids.begin(), last_n_tokens.begin() + n_ctx - n_left / 2 - curr_input_ids.size(),
+    int n_discard = params.n_discard;
+    if (n_discard == -1) n_discard = (n_ctx - curr_input_ids.size() - params.n_keep) / 2;
+    // drop n_discard tokens
+    curr_input_ids.insert(curr_input_ids.begin(), last_n_tokens.begin() + params.n_keep + n_discard,
                           last_n_tokens.end() - curr_input_ids.size());
   }
   model_eval(ctx, &curr_input_ids[0], curr_input_ids.size(), n_past, params.n_threads);
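This branch is the heart of the streaming behaviour: once n_past plus the pending tokens would overflow n_ctx, the first n_keep tokens stay pinned, the next n_discard tokens are dropped (by default half of the evictable window), and the surviving tail of last_n_tokens is re-fed in front of the pending ids. A minimal Python mirror of that arithmetic, written only for illustration (plain lists stand in for the runtime's token buffers):

def swap_context(last_n_tokens, curr_input_ids, n_past, n_ctx, n_keep, n_discard=-1):
    # Returns (new_curr_input_ids, new_n_past); mirrors the C++ branch above.
    if n_past + len(curr_input_ids) <= n_ctx:
        return list(curr_input_ids), n_past  # still fits, nothing to do
    # always keep the first n_keep tokens of the prompt
    n_past = max(1, n_keep)
    # -1 means: drop half of the tokens that are neither kept nor pending
    if n_discard == -1:
        n_discard = (n_ctx - len(curr_input_ids) - n_keep) // 2
    # re-feed the surviving middle of the window in front of the pending tokens
    survivors = last_n_tokens[n_keep + n_discard : len(last_n_tokens) - len(curr_input_ids)]
    return survivors + list(curr_input_ids), n_past

For example, with n_ctx = 512, n_keep = 4 and a single pending token, the default resolves to (512 - 1 - 4) / 2 = 253 discarded tokens, so roughly half of the window is recomputed each time the context fills up.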
@@ -182,17 +182,14 @@ std::vector<model_token> Model::generate_tokens(const std::vector<model_token>& input_ids) {
     last_n_tokens.push_back(item);
   }
   // infinite text generation via context swapping
-  // if we run out of context:
-  // - take the n_keep first tokens from the original prompt (via n_past)
-  // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
   if (n_past + curr_input_ids.size() > n_ctx) {
-    const int n_left = n_past - params.n_keep;
-
-    // always keep the first token - BOS
+    // always keep the first token
     n_past = std::max(1, params.n_keep);

-    // insert n_left/2 tokens at the start of embd from last_n_tokens
-    curr_input_ids.insert(curr_input_ids.begin(), last_n_tokens.begin() + n_ctx - n_left / 2 - curr_input_ids.size(),
+    int n_discard = params.n_discard;
+    if (n_discard == -1) n_discard = (n_ctx - curr_input_ids.size() - params.n_keep) / 2;
+    // drop n_discard tokens
+    curr_input_ids.insert(curr_input_ids.begin(), last_n_tokens.begin() + params.n_keep + n_discard,
                           last_n_tokens.end() - curr_input_ids.size());
   }
   if (ctx->beam_search) {
@@ -374,7 +371,8 @@ PYBIND11_MODULE(polyglot_cpp, m)
           py::arg("max_new_tokens") = -1, py::arg("batch_size") = 512, py::arg("ctx_size") = 512, py::arg("seed") = -1,
           py::arg("threads") = 8, py::arg("repetition_penalty") = 1.1f, py::arg("num_beams") = 1,
           py::arg("do_sample") = false, py::arg("top_k") = 40, py::arg("top_p") = 0.95, py::arg("temperature") = 0.8,
-          py::arg("min_new_tokens") = 0, py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false)
+          py::arg("min_new_tokens") = 0, py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false,
+          py::arg("n_keep") = 0, py::arg("n_discard") = -1)
      .def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids"))
      .def("generate_tokens", &Model::generate_tokens, "Generate tokens with input ids", py::arg("input_ids"))
      .def_static("quant_model", &Model::quant_model, "Quantize model", py::arg("model_path"), py::arg("out_path"),
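With the two extra py::arg entries, the new knobs become reachable from Python through the polyglot_cpp module shown above. A hedged usage sketch; the constructor call, model path, token ids and the other settings are placeholders rather than values taken from this commit:

import polyglot_cpp

model = polyglot_cpp.Model()  # assumes the binding exposes a default constructor
model.init_model(
    "ne-starcoder-q4_0.bin",  # assumed path to a converted/quantized model file
    max_new_tokens=128,
    ctx_size=512,
    n_keep=4,                 # new: tokens pinned at the start of the context window
    n_discard=-1,             # new: -1 = drop half of the evictable window when full
)
output_ids = model.generate(input_ids=[1, 2, 3, 4])  # placeholder token ids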
File 3 of 9
@@ -388,14 +388,13 @@ int main(int argc, char** argv) {
       // - take the n_keep first tokens from the original prompt (via n_past)
       // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
       if (n_past + (int)embd.size() > n_ctx) {
-        const int n_left = n_past - params.n_keep;
-
-        // always keep the first token - BOS
+        // always keep the first token
         n_past = std::max(1, params.n_keep);

-        // insert n_left/2 tokens at the start of embd from last_n_tokens
-        embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left / 2 - embd.size(),
-                    last_n_tokens.end() - embd.size());
+        int n_discard = params.n_discard;
+        if (n_discard == -1) n_discard = (n_ctx - embd.size() - params.n_keep) / 2;
+        // drop n_discard tokens
+        embd.insert(embd.begin(), last_n_tokens.begin() + params.n_keep + n_discard, last_n_tokens.end() - embd.size());

         // stop saving session if we run out of context
         path_session.clear();
File 4 of 9
@@ -241,6 +241,12 @@ bool gpt_params_parse(int argc, char** argv, gpt_params& params) {
         break;
       }
       params.n_keep = std::stoi(argv[i]);
+    } else if (arg == "--n_discard") {
+      if (++i >= argc) {
+        invalid_param = true;
+        break;
+      }
+      params.n_discard = std::stoi(argv[i]);
     } else if (arg == "-m" || arg == "--model") {
       if (++i >= argc) {
         invalid_param = true;
@@ -455,6 +461,10 @@ void gpt_print_usage(int /*argc*/, char** argv, const gpt_params& params) {
   fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
   fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n",
           params.n_keep);
+  fprintf(stderr,
+          " --n_discard number of tokens will be discarded (default: %d, -1 = half of tokens will be "
+          "discarded)\n",
+          params.n_discard);
   if (model_mlock_supported()) {
     fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
   }
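On the command-line side, --n_discard joins the existing --keep flag. A hypothetical invocation, following the subprocess pattern the Python scripts already use; the executable name, model file and prompt flag are assumptions, only --keep and --n_discard come from this diff:

import subprocess

cmd = [
    "./build/bin/main",        # assumed runtime executable
    "-m", "ne-q4_0.bin",       # assumed model file argument
    "-p", "Once upon a time",  # assumed prompt flag
    "--keep", "4",             # pin the first 4 prompt tokens
    "--n_discard", "-1",       # -1: discard half of the evictable window when full
]
subprocess.run(cmd, check=True)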
File 5 of 9
@@ -38,6 +38,7 @@ struct gpt_params {
   int32_t n_ctx = 512; // context size
   int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
   int32_t n_keep = 0; // number of tokens to keep from initial prompt
+  int32_t n_discard = -1; // number of tokens to drop when reaching n_ctx
   int32_t n_gpu_layers = 0; // number of layers to store in VRAM

   // sampling parameters
File 6 of 9
@@ -19,7 +19,7 @@
 from transformers import AutoConfig
 import subprocess

-model_maps = {"gpt_neox": "gptneox"}
+model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}


 def convert_model(model, outfile, outtype):
File 7 of 9
@@ -19,7 +19,7 @@
 import subprocess
 from transformers import AutoTokenizer

-model_maps = {"gpt_neox": "gptneox", "llama2": "llama"}
+model_maps = {"gpt_neox": "gptneox", "llama2": "llama", "gpt_bigcode": "starcoder"}
 build_path = Path(Path(__file__).parent.absolute(), "../build/")

 def main(args_in: Optional[List[str]] = None) -> None:
File 8 of 9
@@ -18,7 +18,7 @@
 from typing import List, Optional
 import subprocess

-model_maps = {"gpt_neox": "gptneox", "llama2": "llama"}
+model_maps = {"gpt_neox": "gptneox", "llama2": "llama", "gpt_bigcode": "starcoder"}
 build_path = Path(Path(__file__).parent.absolute(), "../build/")

 def str2bool(v):
File 9 of 9
@@ -19,7 +19,7 @@
 from transformers import AutoConfig
 import subprocess

-model_maps = {"gpt_neox": "gptneox"}
+model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}
 build_path = Path(Path(__file__).parent.absolute(), "../build/")

 def str2bool(v):
