@@ -168,6 +168,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
168168 break ;
169169 }
170170 params.n_ctx = std::stoi (argv[i]);
171+ } else if (arg == " -gqa" || arg == " --gqa" ) {
172+ if (++i >= argc) {
173+ invalid_param = true ;
174+ break ;
175+ }
176+ params.n_gqa = std::stoi (argv[i]);
171177 } else if (arg == " --rope-freq-base" ) {
172178 if (++i >= argc) {
173179 invalid_param = true ;
@@ -485,6 +491,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
485491 fprintf (stdout, " -f FNAME, --file FNAME\n " );
486492 fprintf (stdout, " prompt file to start generation.\n " );
487493 fprintf (stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n " , params.n_predict );
494+ fprintf (stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n " , params.n_ctx );
495+ fprintf (stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n " , params.n_batch );
496+ fprintf (stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n " , params.n_gqa );
488497 fprintf (stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n " , params.top_k );
489498 fprintf (stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n " , (double )params.top_p );
490499 fprintf (stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n " , (double )params.tfs_z );
@@ -505,15 +514,13 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
505514 fprintf (stdout, " --cfg-negative-prompt PROMPT \n " );
506515 fprintf (stdout, " negative prompt to use for guidance. (default: empty)\n " );
507516 fprintf (stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n " , params.cfg_scale );
508- fprintf (stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n " , params.n_ctx );
509517 fprintf (stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n " , params.rope_freq_base );
510518 fprintf (stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n " , params.rope_freq_scale );
511519 fprintf (stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n " );
512520 fprintf (stdout, " --no-penalize-nl do not penalize newline token\n " );
513521 fprintf (stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n " );
514522 fprintf (stdout, " not recommended: doubles context memory required and no measurable increase in quality\n " );
515523 fprintf (stdout, " --temp N temperature (default: %.1f)\n " , (double )params.temp );
516- fprintf (stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n " , params.n_batch );
517524 fprintf (stdout, " --perplexity compute perplexity over each ctx window of the prompt\n " );
518525 fprintf (stdout, " --perplexity-lines compute perplexity over each line of the prompt\n " );
519526 fprintf (stdout, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n " , params.n_keep );
@@ -580,6 +587,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
580587
581588 lparams.n_ctx = params.n_ctx ;
582589 lparams.n_batch = params.n_batch ;
590+ lparams.n_gqa = params.n_gqa ;
583591 lparams.n_gpu_layers = params.n_gpu_layers ;
584592 lparams.main_gpu = params.main_gpu ;
585593 lparams.tensor_split = params.tensor_split ;
0 commit comments