Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : add grammar-based sampling #1773

Merged
merged 22 commits into from Jul 24, 2023
Merged

llama : add grammar-based sampling #1773

merged 22 commits into from Jul 24, 2023

Conversation

ejones
Copy link
Collaborator

@ejones ejones commented Jun 9, 2023

EDITED after updates

Inspired by #1397 and grantslatton's CFG work, this adds an API that takes a serialized context-free grammar to guide and constrain sampling. Also adds a sample Backus-Naur form (BNF)-like syntax in main for specifying a grammar for generations.

Testing

(M2 Max, 30B)

Chess
 % ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A good game:\n\n' --grammar-file grammars/chess.gbnf
main: build = 674 (e550234)
main: seed  = 1688014137
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


main: grammar:
root ::= [1] [.] [ ] move [ ] move [<U+000A>] root_4 
move ::= move_5 move_9 
root_2 ::= [1-9] root_3 [.] [ ] move [ ] move [<U+000A>] 
root_3 ::= [0-9] | 
root_4 ::= root_2 root_4 | root_2 
move_5 ::= pawn | nonpawn | castle 
pawn ::= pawn_14 [a-h] [1-8] pawn_16 
nonpawn ::= [NBKQR] nonpawn_10 nonpawn_11 nonpawn_12 [a-h] [1-8] 
castle ::= [O] [-] [O] castle_17 
move_9 ::= [+#] | 
nonpawn_10 ::= [a-h] | 
nonpawn_11 ::= [1-8] | 
nonpawn_12 ::= [x] | 
pawn_13 ::= [a-h] [x] 
pawn_14 ::= pawn_13 | 
pawn_15 ::= [=] [NBKQR] 
pawn_16 ::= pawn_15 | 
castle_17 ::= [-] [O] | 

 A good game:

1. e4 e5
2. Nf3 Nc6
3. Bb5 a6
4. Ba4 Nf6

llama_print_timings:        load time =  1144.33 ms
llama_print_timings:      sample time =    35.87 ms /    32 runs   (    1.12 ms per token)
llama_print_timings: prompt eval time =  1126.34 ms /     7 tokens (  160.91 ms per token)
llama_print_timings:        eval time =  5214.99 ms /    31 runs   (  168.23 ms per token)
llama_print_timings:       total time =  6398.45 ms
"Chess" without grammar
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A good game:\n\n'  

main: build = 645 (fd0eb66)
main: seed  = 1686286016
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


 A good game:

Sir Thomas Gresham, when he was building his famous Exchange at London, had the following dialogue with a mason, whose name was Richard B
llama_print_timings:        load time =  1185.47 ms
llama_print_timings:      sample time =    21.57 ms /    32 runs   (    0.67 ms per token)
llama_print_timings: prompt eval time =  1167.67 ms /     7 tokens (  166.81 ms per token)
llama_print_timings:        eval time =  4977.97 ms /    31 runs   (  160.58 ms per token)
llama_print_timings:       total time =  6188.21 ms
Arithmetic
 % ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Some arithmetic practice:\n\n' \                      
--grammar 'root  ::= (expr "=" ws num "\n")+
expr  ::= term ([-+*/] term)*
term  ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num   ::= [0-9]+ ws
ws    ::= [ \t\n]*'
main: build = 674 (e550234)
main: seed  = 1688014196
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


main: grammar:
root ::= root_5 
root_1 ::= expr [=] ws num [<U+000A>] 
expr ::= term expr_8 
ws ::= ws_12 
num ::= num_11 ws 
root_5 ::= root_1 root_5 | root_1 
term ::= ident | num | [(] ws expr [)] ws 
expr_7 ::= [-+*/] term 
expr_8 ::= expr_7 expr_8 | 
ident ::= [a-z] ident_10 ws 
ident_10 ::= [a-z0-9_] ident_10 | 
num_11 ::= [0-9] num_11 | [0-9] 
ws_12 ::= [ <U+0009><U+000A>] ws_12 | 

 Some arithmetic practice:

10 *a*1 +b*2 =640

10 *a*2 +b*3 =656


llama_print_timings:        load time =  1165.00 ms
llama_print_timings:      sample time =    41.11 ms /    32 runs   (    1.28 ms per token)
llama_print_timings: prompt eval time =  1147.76 ms /     7 tokens (  163.97 ms per token)
llama_print_timings:        eval time =  5113.92 ms /    31 runs   (  164.97 ms per token)
llama_print_timings:       total time =  6323.27 ms
Arithmetic - no grammar
 % ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Some arithmetic practice:\n\n'                                            
main: build = 645 (fd0eb66)
main: seed  = 1686286388
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


 Some arithmetic practice:

\begin{code}
package main

import (
    "fmt"
)

func main() {
    fmt.Println(
llama_print_timings:        load time =  1171.65 ms
llama_print_timings:      sample time =    21.37 ms /    32 runs   (    0.67 ms per token)
llama_print_timings: prompt eval time =  1153.88 ms /     7 tokens (  164.84 ms per token)
llama_print_timings:        eval time =  4991.68 ms /    31 runs   (  161.02 ms per token)
llama_print_timings:       total time =  6187.91 ms
JSON
% ./main -m $LLAMA_30B_Q4_0 -n 64 -p $'A bit about me:\n\n' --grammar-file grammars/json.gbnf
main: build = 674 (e550234)
main: seed  = 1688014289
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 64, n_keep = 0


main: grammar:
root ::= object 
object ::= [{] ws object_11 [}] 
value ::= object | array | string | number | boolean 
array ::= [[] ws array_15 []] 
string ::= ["] string_16 ["] ws 
number ::= number_17 number_18 ws 
boolean ::= boolean_19 ws 
ws ::= [ <U+0009><U+000A>] ws | 
object_8 ::= string [:] ws value object_10 
object_9 ::= [,] ws string [:] ws value 
object_10 ::= object_9 object_10 | 
object_11 ::= object_8 | 
array_12 ::= value array_14 
array_13 ::= [,] ws value 
array_14 ::= array_13 array_14 | 
array_15 ::= array_12 | 
string_16 ::= [ <U+0009>!#-[]-~] string_16 | 
number_17 ::= [-] | 
number_18 ::= [0-9] number_18 | [0-9] 
boolean_19 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] 

 A bit about me:

{
	"fullName": "Ramon Rodriguez",
	"username": "ramon",
	"email": "ramon@mail.com",
	"phoneNumber": "+1234567890",
	"address": {
		
llama_print_timings:        load time =  1273.70 ms
llama_print_timings:      sample time =    82.93 ms /    64 runs   (    1.30 ms per token)
llama_print_timings: prompt eval time =  1256.36 ms /     8 tokens (  157.04 ms per token)
llama_print_timings:        eval time = 10432.05 ms /    63 runs   (  165.59 ms per token)
llama_print_timings:       total time = 11795.36 ms
"JSON" - no grammar
 % ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'A bit about me:\n\n'                                                                          
main: build = 645 (fd0eb66)
main: seed  = 1686286615
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


 A bit about me:

A former teacher, now a full-time writer. I am the author of two novels: _The Man in the Moon_ and _The Riddle
llama_print_timings:        load time =  1291.32 ms
llama_print_timings:      sample time =    21.48 ms /    32 runs   (    0.67 ms per token)
llama_print_timings: prompt eval time =  1274.63 ms /     8 tokens (  159.33 ms per token)
llama_print_timings:        eval time =  4990.01 ms /    31 runs   (  160.97 ms per token)
llama_print_timings:       total time =  6306.01 ms
Japanese
 % ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Building a website can be done in 10 simple steps (from the original Japanese):\n\n' --grammar-file grammars/japanese.gbnf
main: build = 674 (e550234)
main: seed  = 1688013430
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


main: grammar:
root ::= root_2 root_5 
jp-char ::= hiragana | katakana | punctuation | cjk 
root_2 ::= jp-char root_2 | jp-char 
root_3 ::= [ <U+0009><U+000A>] root_4 
root_4 ::= jp-char root_4 | jp-char 
root_5 ::= root_3 root_5 | 
hiragana ::= [<U+3041>-<U+309F>] 
katakana ::= [<U+30A1>-<U+30FF>] 
punctuation ::= [<U+3001>-<U+303E>] 
cjk ::= [<U+4E00>-<U+9FFF>] 

 Building a website can be done in 10 simple steps (from the original Japanese):

一、目的は何なのか
二、お客さまを思い出して
三、お客さまのこと
llama_print_timings:        load time =  2957.19 ms
llama_print_timings:      sample time =    42.67 ms /    32 runs   (    1.33 ms per token)
llama_print_timings: prompt eval time =  2941.56 ms /    21 tokens (  140.07 ms per token)
llama_print_timings:        eval time =  5384.28 ms /    31 runs   (  173.69 ms per token)
llama_print_timings:       total time =  8387.61 ms
Japanese - no grammar
% ./main -m $LLAMA_30B_Q4_0 -n 32 -p $'Building a website can be done in 10 simple steps (from the original Japanese):\n\n' 
main: build = 674 (e550234)
main: seed  = 1688013483
llama.cpp: loading model from /Users/evan/llama-models/30B/ggml-model-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 6656
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 52
llama_model_load_internal: n_layer    = 60
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 17920
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 30B
llama_model_load_internal: ggml ctx size =    0.13 MB
llama_model_load_internal: mem required  = 19756.66 MB (+ 3124.00 MB per state)
.
llama_init_from_file: kv self size  =  780.00 MB

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 32, n_keep = 0


 Building a website can be done in 10 simple steps (from the original Japanese):

1. Determine your goal for your site.
2. Make a plan.
3. Select the domain name.
4. Choose web
llama_print_timings:        load time =  2955.05 ms
llama_print_timings:      sample time =    22.96 ms /    32 runs   (    0.72 ms per token)
llama_print_timings: prompt eval time =  2937.10 ms /    21 tokens (  139.86 ms per token)
llama_print_timings:        eval time =  5032.41 ms /    31 runs   (  162.34 ms per token)
llama_print_timings:       total time =  8013.71 ms

Approach

Grammar API

The llama API accepts a data structure representing a context-free grammar over 32-bit code points:

    // grammar element type
    enum llama_gretype {
        // end of rule definition
        LLAMA_GRETYPE_END            = 0,

        // start of alternate definition for rule
        LLAMA_GRETYPE_ALT            = 1,

        // non-terminal element: reference to rule
        LLAMA_GRETYPE_RULE_REF       = 2,

        // terminal element: character (code point)
        LLAMA_GRETYPE_CHAR           = 3,

        // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
        // be an inclusive range ([a-z])
        LLAMA_GRETYPE_CHAR_RNG_UPPER = 4,

        // modifies a preceding LLAMA_GRETYPE_CHAR or
        // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
        LLAMA_GRETYPE_CHAR_ALT       = 5,
    };

    typedef struct llama_grammar_element {
        enum llama_gretype type;
        uint32_t           value; // Unicode code point or rule ID
    } llama_grammar_element;

    LLAMA_API struct llama_grammar * llama_grammar_init(
            const llama_grammar_element ** rules,
                                 size_t    n_rules,
                                 size_t    start_rule_index);

Sampling

The grammar sampling code models a nondeterministic pushdown automaton, maintaining N stacks for the possible parse states. Sampling a token is done in two steps: a sampling API that filters candidates to those that match one of the parse stacks (llama_sample_grammar) and adding the chose token to the grammar (llama_grammar_accept_token).

Examples

Adds --grammar and --grammar-file arguments to main taking a simple extended BNF to constrain generations. The parser for this format is implemented in examples/grammar-parser.{h,cpp}:

// ... Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root  ::= expr
// expr  ::= term ([-+*/] term)*
// term  ::= num | "(" space expr ")" space
// num   ::= [0-9]+ space
// space ::= [ \t\n]*

The root rule identifies the start of the grammar.

## Caveats

  • the binary format makes the code harder to understand and more brittle
  • the grammar contemplates 16-bit chars but it's just being applied to the 8-bit UTF-8 chars in token strings currently
  • the 1-char lookahead sampling is probably biasing generations in a weird way; further investigation on quality of outputs is probably needed

@ggerganov ggerganov added the high priority Very important issue label Jun 9, 2023
@howard0su
Copy link
Collaborator

Suggest taking a file as grammar parameter and put several examples like what we did for prompts (in .\prompts folder).

@tobi
Copy link
Sponsor Collaborator

tobi commented Jun 10, 2023

Incredibly useful contribution. It's really amazing how much this simplifies many use cases.

I agree that it would be better if the grammar came from a file.

Two snags I hit while trying this out:

  • it crashes with --prompt-cache
  • any empty lines in the grammar cause a crash

Some additional thoughts:

  • Would love to have the grammars support empty lines and comments
  • I wonder if the grammar could be compiled into a tensor of state transitions and run on the GPU
  • I wonder if there is an optimization where the next token is already known form the grammar we could skip the inference and just add it? In many types of grammars like json or html that could really speed up generation
  • I think it's worth allowing to reference full tokens form the grammar. Maybe something like @“ token” or @13432 Id of token.

@slaren
Copy link
Collaborator

slaren commented Jun 11, 2023

Very nice! I am wondering what is the rationale for not including the parser in the llama.cpp API. Without it, most downstream users will be forced to manually make a copy of the parser in their code to support the feature, which is not great.
Also for usability, I think it would be a good idea to keep a copy of the binary grammar in llama_grammar, rather than asking the users to keep the provided copy alive. The overhead would be minimal, and it would simplify the code of downstream users.

@ejones
Copy link
Collaborator Author

ejones commented Jun 12, 2023

Thanks all! Just added support for grammar files (with examples) and updated the grammar syntax to add shell-style comments and allow empty lines between rules, as well as newlines inside parenthesized groups.

it crashes with --prompt-cache

I wonder if that was #1699 ? If so, should be fixed now

I wonder if the grammar could be compiled into a tensor of state transitions and run on the GPU

Sounds cool, I don't know enough about GPU programming to comment on that myself. The grammar participates in the sampling layer, and I'm not sure if that leverages the GPU currently.

I wonder if there is an optimization where the next token is already known form the grammar we could skip the inference and just add it?

This is definitely possible. That said, AFAIK the token would still need to be evaluated, and that seems to be the bottleneck. Maybe the optimization comes in being able to batch eval strings of such tokens?

I think it's worth allowing to reference full tokens form the grammar

Neat idea. Would that be more of an optimization or to reference tokens that can't be expressed textually?

what is the rationale for not including the parser in the llama.cpp API.

Honestly, I was trying to reduce the changes to llama.cpp itself. Agree it would be more convenient in the API.

I think it would be a good idea to keep a copy of the binary grammar

Makes sense. I left that out of this round of changes - if it's desired to have the grammar parser in the llama API, this may naturally fit with that change.

@bullno1
Copy link
Contributor

bullno1 commented Jun 12, 2023

First, this is amazing work.

This makes me wonder whether the entire sampling API should be pulled into something like llama_samplers instead.
External samplers can evolve independently of the core API.

The existing functions can be kept for compatibility.
AFAIK, the only thing we need is to expose the RNG.
And even then, the existence of that inside a state/context is debatable.
The context window is already managed by user code so why not sampling?

This reminds me a lot of: https://lmql.ai/.
There is also https://github.com/1rgs/jsonformer where the input is a json schema which is not always easy to express in BNF.

AFAIK the token would still need to be evaluated

Would it though?
We just immediately add it to the context.
It is done manually in user code now.

Maybe the optimization comes in being able to batch eval strings of such tokens?

AFAIK, that's the case.
The initial prompt and the user input are submitted in a large batch.
The inference loop just feed the single chosen token back until eos.

The grammar participates in the sampling layer, and I'm not sure if that leverages the GPU currently.

The current sampling is CPU.

@Green-Sky
Copy link
Collaborator

This makes me wonder whether the entire sampling API should be pulled into something like llama_samplers instead.

one of the discussion points for adding more llm generic tooling back into ggml(repo) was moving the sampler there. but afaik nothing happened yet :)

@ejones
Copy link
Collaborator Author

ejones commented Jun 12, 2023

There is also https://github.com/1rgs/jsonformer where the input is a json schema

Was planning to tackle this next. I've got it more or less working locally in a branch off of this, at least with the examples on jsonformer's README. It uses a Python script to generate a JSON BNF that conforms to the schema.

@@ -263,6 +289,9 @@ extern "C" {
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);

/// @details Apply constraints from grammar
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make llama_grammar as a structure with two callbacks? So the other implementation of it can support context aware state machine instead?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean like, the caller would provide the implementation of llama_grammar (via callbacks), from which the llama API determines which tokens are valid?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, so llama code will not assume the grammar implementation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'm open to that idea, assuming the grammar interface itself generalizes well to other implementations. I kind of designed this with the specific implementation in mind so that's not a guarantee.

@ggerganov
Copy link
Owner

Great stuff!

I'm still wrapping my head around this.

  • Yes, this can become part of a llama.cpp or ggml sampling API, but I guess for now we can keep it as example and see what are the pros and cons and learn how to use it most efficiently
  • What happens when then next N > 1 tokens are uniquely determined by the grammar? I guess we will sample them one by one, correct? What would it take to make it so that they are submitted to be processed as a batch? This would significantly speed up the inference in such cases

@ejones
Copy link
Collaborator Author

ejones commented Jun 16, 2023

  • Yes, this can become part of a llama.cpp or ggml sampling API, but I guess for now we can keep it as example and see what are the pros and cons and learn how to use it most efficiently

To clarify, this PR adds the core sampling functionality in llama.cpp, leaving the grammar parser out in examples. Should that all be moved to examples or just left as is?

  • What happens when then next N > 1 tokens are uniquely determined by the grammar? I guess we will sample them one by one, correct? What would it take to make it so that they are submitted to be processed as a batch? This would significantly speed up the inference in such cases

Yes, that's correct. I think that's doable, I can take a stab at that.

@SlyEcho
Copy link
Sponsor Collaborator

SlyEcho commented Jun 16, 2023

the grammar contemplates 16-bit chars but it's just being applied to the 8-bit UTF-8 chars in token strings currently

I don't understand this part. So it is converting to UTF-16?

Another option would be to use token values but it will be more limiting.

EDIT: I read through the code.

The grammar doesn't care about the text encoding. It could work with any encoding, provided that the rules match the characters correctly.

The parser doesn't understand UTF-8 so it will create rules that don't match as the user expects.

For example, if I wanted to create a rule to match all Hiragana characters, I should be able to write:

[ぁ-ゖ]

However the parser doesn't see it as two characters separated by -, instead:

[\xe3\x81\x81-\xe3\x82\x96]

But the correct rule should be something like this?

"\xe3" [\x81-\x82] [\x81-\x96]

llama.cpp Outdated Show resolved Hide resolved
@ivanstepanovftw
Copy link
Collaborator

Just dont use repeat penalties to get best grammar as llama can

@ggerganov
Copy link
Owner

To clarify, this PR adds the core sampling functionality in llama.cpp, leaving the grammar parser out in examples. Should that all be moved to examples or just left as is?

It's fine the way it is

@burke
Copy link
Sponsor

burke commented Jun 16, 2023

FWIW I'm adapting this code into an analogous feature for models running on torch. In my implementation, I'm doing grammar enforcement logit masking on the GPU across the full token set before selecting candidates: https://github.com/Shopify/torch-grammar/blob/df23e354083c909c70120e256ed34036c93f6714/grammar_sampler.py#L232-L239. The same strategy would probably work here if anyone was super motivated to try it.

llama.h Outdated Show resolved Hide resolved
@ggerganov
Copy link
Owner

Adding a similar grammar-based sampling to whisper.cpp would be a really cool contribution.
Here is one ongoing implementation that can potentially benefit from such functionality by using VIM-based grammar:

ggerganov/whisper.cpp#1144

@x4080
Copy link

x4080 commented Aug 12, 2023

Can we improve the result by fine tuning the model ? If so, what is the example to improve it ?
TIA

@ejones
Copy link
Collaborator Author

ejones commented Aug 14, 2023

@ggerganov agreed! Although I'm not sure when or if I'll be able to contribute that.

@x4080 this approach is independent of the model variant and can be used with a fine-tune. In the comments above there's a demonstration of using WizardLM, for example.

@x4080
Copy link

x4080 commented Aug 14, 2023

@ejones thanks

@RevanthRameshkumar
Copy link

What is the generation speed on this? How does it compare to unconstrained generation when using cpp?

I am trying to replicate this in python using an A100 on an 8bit quantized llama 7b and it is extremely slow per token (compared to unconstrained generation) due to all the extra encoding and decoding that needs to happen.

@ejones
Copy link
Collaborator Author

ejones commented Aug 27, 2023

I've only done CPU inference, but the performance impact has been insignificant for everything I've tried. On the M2 Max I'm seeing about ~0.5 ms / token sampling for unconstrained vs ~6ms with a grammar, with token eval taking about ~70ms for 13b (Q4_K).

I'm not sure if I know enough about GPU programming to meaningfully comment; I know that other folks are working on approaches that are more generic or are GPU- and/or Python-focused. There are some examples of this upthread: https://github.com/Shopify/torch-grammar and https://github.com/normal-computing/outlines.

@ejones
Copy link
Collaborator Author

ejones commented Aug 31, 2023

@ggerganov alright, I've done it: ggerganov/whisper.cpp#1229

@kalomaze
Copy link
Contributor

kalomaze commented Sep 23, 2023

Is it a natural side effect of the grammar sampling method that I am seeing a significant degradation in tokens per second speed during generation, even with permissive sampling rules?
On 13b, I go from ~20T/s to ~13T/s, which is pretty unfortunate in my eyes, but is it expected? Is it CPU bound potentially?

@Green-Sky
Copy link
Collaborator

It tells you the time spent on sampling at the end. You can confirm it there.
eg:

llama_print_timings:        load time =   285,92 ms
llama_print_timings:      sample time =    30,81 ms /    37 runs   (    0,83 ms per token,  1200,95 tokens per second)
llama_print_timings: prompt eval time =   159,23 ms /     5 tokens (   31,85 ms per token,    31,40 tokens per second)
llama_print_timings:        eval time =  3656,10 ms /    36 runs   (  101,56 ms per token,     9,85 tokens per second)
llama_print_timings:       total time =  3861,11 ms

@ejones
Copy link
Collaborator Author

ejones commented Sep 26, 2023

Yeah, I generally see about ~5ms/token overhead for grammars on the M2 Max, which is usually a fraction of the per token eval time. But recently I was testing with a grammar and saw a more significant impact. Should investigate, there may be some pathological cases.

lenaxia pushed a commit to lenaxia/home-ops-prod that referenced this pull request Apr 27, 2024
…d grammars by including the `messages` field and adjusting the endpoint to `/v1/chat/completions`.

# Aider chat conversation:

USER: https://localai.io/features/constrained_grammars/:

-   [](/)

    LocalAI

-

-

-   [*info* Overview](https://localai.io/)
-   *rocket_launch* Getting started

    -   [Quickstart](https://localai.io/basics/getting_started/)
    -   [Run other
        Models](https://localai.io/docs/getting-started/run-other-models/)
    -   [Customizing the
        Model](https://localai.io/docs/getting-started/customize-model/)
    -   [Run models
        manually](https://localai.io/docs/getting-started/manual/)
    -   [Build LocalAI from source](https://localai.io/basics/build/)

-   [*newspaper* News](https://localai.io/basics/news/)
-   *feature_search* Features

    -   [⚡ GPU
        acceleration](https://localai.io/features/gpu-acceleration/)
    -   [📖 Text generation
        (GPT)](https://localai.io/features/text-generation/)
    -   [🗣 Text to audio
        (TTS)](https://localai.io/features/text-to-audio/)
    -   [🎨 Image
        generation](https://localai.io/features/image-generation/)
    -   [🧠 Embeddings](https://localai.io/features/embeddings/)
    -   [🆕 GPT Vision](https://localai.io/features/gpt-vision/)
    -   [✍️ Constrained
        grammars](https://localai.io/features/constrained_grammars/)
    -   [🔈 Audio to text](https://localai.io/features/audio-to-text/)
    -   [🔥 OpenAI functions and
        tools](https://localai.io/features/openai-functions/)
    -   [💾 Stores](https://localai.io/stores/)
    -   [🖼️ Model gallery](https://localai.io/models/)

-   [*rocket_launch*
    Integrations](https://localai.io/docs/integrations/)
-   *science* Advanced

    -   [Advanced usage](https://localai.io/advanced/)
    -   [Fine-tuning LLMs for text
        generation](https://localai.io/docs/advanced/fine-tuning/)

-   *science* References

    -   [Model compatibility
        table](https://localai.io/model-compatibility/)
    -   [Architecture](https://localai.io/docs/reference/architecture/)
    -   [Available Container
        images](https://localai.io/docs/reference/container-images/)
    -   [All-In-One
        images](https://localai.io/docs/reference/aio-images/)

-   [*quiz* FAQ](https://localai.io/faq/)

::::::::::::::::::::::::::::::::::::::::::::::::: {role="main"}

[](/)

menu

search Search

[ ]{.kbd} [ ]{.kbd}

-   [](%20https://github.com/mudler/LocalAI%20)

-   [](%20https://twitter.com/LocalAI_API%20)

-   [](%20https://localai.io/index.xml%20)

[]{.kbd} []{.kbd} to navigate

[]{.kbd} to select

[]{.kbd} to close

cancel

-   [*Home*](/docs/)
-   [Features](/features/)
-   ✍️ Constrained grammars

On this page

-   -   [Setup](#setup)
    -   [💡 Usage example](#-usage-example)

Table of Contents

-   -   [Setup](#setup)
    -   [💡 Usage example](#-usage-example)

*article*

# ✍️ Constrained grammars

The chat endpoint accepts an additional `grammar` parameter which takes
a [BNF defined
grammar](https://en.wikipedia.org/wiki/Backus%E2%80%93Naur_form).

This allows the LLM to constrain the output to a user-defined schema,
allowing to generate `JSON`, `YAML`, and everything that can be defined
with a BNF grammar.

notifications

This feature works only with models compatible with the
[llama.cpp](https://github.com/ggerganov/llama.cpp) backend (see also
[Model compatibility](/model-compatibility/)). For details on how it
works, see the upstream PRs:
<ggerganov/llama.cpp#1773>,
<ggerganov/llama.cpp#1887>

## Setup [*link*](#setup)

Follow the setup instructions from the [LocalAI
functions](/features/openai-functions/) page.

## 💡 Usage example [*link*](#-usage-example)

For example, to constrain the output to either `yes`, `no`:

    curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
      "model": "gpt-4",
      "messages": [{"role": "user", "content": "Do you like apples?"}],
      "grammar": "root ::= (\"yes\" | \"no\")"
    }'

[ Edit this
page](https://github.com/mudler/LocalAI/blob/master/docs/content/docs/features/constrained_grammars.md)

Last updated 3 months ago. history

------------------------------------------------------------------------

[](/features/gpt-vision/)

*navigate_before* 🆕 GPT Vision

[](/features/audio-to-text/)

🔈 Audio to text *navigate_next*

© 2023-2024 [Ettore Di Giacinto](https://mudler.pm)

:::::::::::::::::::::::::::::::::::::::::::::::::

ASSISTANT: Based on the information provided about the LocalAI's constrained grammars feature, it seems that the `handleCompletions` function in the `rag-api.go` file needs to be updated to support the constrained grammars when making a request to the LocalAI service. The request payload should include a `grammar` field when it is provided in the `CompletionRequest`.

Here is the diff for the necessary changes:

```diff
--- cluster/apps/home/localai/app/config/rag-api.go
+++ cluster/apps/home/localai/app/config/rag-api.go
@@ ... @@
 func handleCompletions(w http.ResponseWriter, r *http.Request) {
 	atomic.AddUint64(&requestMetrics.CompletionRequests, 1)
 	logRequest(r)

 	var req CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, err.Error(), http.StatusBadRequest)
 		return
 	}
 	defer r.Body.Close()

 	embedding, err := getEmbeddings(req.Prompt)
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}

-	// Generate completion using constrained grammar (if requested)
-	var completion CompletionResponse
-	if req.ConstrainedGrammar != "" {
-		payload := map[string]interface{}{
-			"model":   "gpt-4",
-			"prompt":  req.Prompt,
-			"grammar": req.ConstrainedGrammar,
-			"max_tokens": req.MaxTokens,
-			"temperature": req.Temperature,
-			"top_p": req.TopP,
-		}
-		jsonPayload, err := json.Marshal(payload)
-		if err != nil {
-			http.Error(w, err.Error(), http.StatusInternalServerError)
-			return
-		}
+	// Generate completion using constrained grammar or retrieval-augmented generation (if requested)
+	var completionResp CompletionResponse
+	payload := map[string]interface{}{
+		"model":       "gpt-4",
+		"prompt":      req.Prompt,
+		"max_tokens":  req.MaxTokens,
+		"temperature": req.Temperature,
+		"top_p":       req.TopP,
+	}
+	if req.ConstrainedGrammar != "" {
+		payload["grammar"] = req.ConstrainedGrammar
+	}
+	jsonPayload, err := json.Marshal(payload)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}

 	localAI := os.Getenv("LOCAL_AI_ENDPOINT")
 	if localAI == "" {
 		localAI = defaultLocalAI
 	}

-	resp, err := http.Post(localAI+"/v1/chat/completions", "application/json", bytes.NewBuffer(jsonPayload))
+	resp, err := http.Post(localAI+"/completions", "application/json", bytes.NewBuffer(jsonPayload))
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}
 	defer resp.Body.Close()

 	if resp.StatusCode != http.StatusOK {
 		body, err := io.ReadAll(resp.Body)
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
 		}
 		http.Error(w, fmt.Sprintf("completion request failed with status code: %d: %s", resp.StatusCode, body), resp.StatusCode)
 		return
 	}

-	var respBody struct {
-		Result CompletionResponse `json:"result"`
-	}
-	if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
+	if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}

-	completion = respBody.Result
+	jsonResp, err := json.Marshal(completionResp)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}

-	jsonResp, err := json.Marshal(completion)
-	if err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
-	}

 	w.Header().Set("Content-Type", "application/json")
 	w.WriteHeader(http.StatusOK)
 	w.Write(jsonResp)
 }
```

This diff updates the `handleCompletions` function to include the `grammar` field in the request payload if it is provided in the `CompletionRequest`. It also corrects the endpoint to `/completions` as per the LocalAI documentation.

USER: UnifiedDiffNoMatch: hunk failed to apply!

cluster/apps/home/localai/app/config/rag-api.go does not contain lines that match the diff you provided!
Try again.
DO NOT skip blank lines, comments, docstrings, etc!
The diff needs to apply cleanly to the lines in cluster/apps/home/localai/app/config/rag-api.go!

cluster/apps/home/localai/app/config/rag-api.go does not contain these 76 exact lines in a row:
```
func handleCompletions(w http.ResponseWriter, r *http.Request) {
	atomic.AddUint64(&requestMetrics.CompletionRequests, 1)
	logRequest(r)

	var req CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	defer r.Body.Close()

	embedding, err := getEmbeddings(req.Prompt)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Generate completion using constrained grammar (if requested)
	var completion CompletionResponse
	if req.ConstrainedGrammar != "" {
		payload := map[string]interface{}{
			"model":   "gpt-4",
			"prompt":  req.Prompt,
			"grammar": req.ConstrainedGrammar,
			"max_tokens": req.MaxTokens,
			"temperature": req.Temperature,
			"top_p": req.TopP,
		}
		jsonPayload, err := json.Marshal(payload)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

	localAI := os.Getenv("LOCAL_AI_ENDPOINT")
	if localAI == "" {
		localAI = defaultLocalAI
	}

	resp, err := http.Post(localAI+"/v1/chat/completions", "application/json", bytes.NewBuffer(jsonPayload))
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		http.Error(w, fmt.Sprintf("completion request failed with status code: %d: %s", resp.StatusCode, body), resp.StatusCode)
		return
	}

	var respBody struct {
		Result CompletionResponse `json:"result"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	completion = respBody.Result

	jsonResp, err := json.Marshal(completion)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	w.Write(jsonResp)
}
```

ASSISTANT: I apologize for the confusion. Let's try again with the correct diff that applies cleanly to the `cluster/apps/home/localai/app/config/rag-api.go` file. I will update the `handleCompletions` function to correctly handle the constrained grammars feature as described in the LocalAI documentation.

```diff
--- cluster/apps/home/localai/app/config/rag-api.go
+++ cluster/apps/home/localai/app/config/rag-api.go
@@ ... @@
 func handleCompletions(w http.ResponseWriter, r *http.Request) {
 	atomic.AddUint64(&requestMetrics.CompletionRequests, 1)
 	logRequest(r)

 	var req CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, err.Error(), http.StatusBadRequest)
 		return
 	}
 	defer r.Body.Close()

 	// Generate completion using constrained grammar (if requested)
 	var completion CompletionResponse
 	if req.ConstrainedGrammar != "" {
 		payload := map[string]interface{}{
 			"model":   "gpt-4",
 			"prompt":  req.Prompt,
-			"grammar": req.ConstrainedGrammar,
+			"grammar": req.ConstrainedGrammar,
 			"max_tokens": req.MaxTokens,
 			"temperature": req.Temperature,
 			"top_p": req.TopP,
 		}
 		jsonPayload, err := json.Marshal(payload)
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
 		}

 		localAI := os.Getenv("LOCAL_AI_ENDPOINT")
 		if localAI == "" {
 			localAI = defaultLocalAI
 		}

-		resp, err := http.Post(localAI+"/v1/chat/completions", "application/json", bytes.NewBuffer(jsonPayload))
+		resp, err := http.Post(localAI+"/completions", "application/json", bytes.NewBuffer(jsonPayload))
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
 		}
 		defer resp.Body.Close()

 		if resp.StatusCode != http.StatusOK {
 			body, err := io.ReadAll(resp.Body)
 			if err != nil {
 				http.Error(w, err.Error(), http.StatusInternalServerError)
 				return
 			}
 			http.Error(w, fmt.Sprintf("completion request failed with status code: %d: %s", resp.StatusCode, body), resp.StatusCode)
 			return
 		}

 		var respBody struct {
 			Result CompletionResponse `json:"result"`
 		}
 		if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
 		}

 		completion = respBody.Result
 	} else {
 		// Implement retrieval-augmented generation using retrieved data
 		// ...
 	}

 	jsonResp, err := json.Marshal(completion)
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}

 	w.Header().Set
^C KeyboardInterrupt

USER: https://localai.io/features/constrained_grammars/:

-   [](/)

    LocalAI

-

-

-   [*info* Overview](https://localai.io/)
-   *rocket_launch* Getting started

    -   [Quickstart](https://localai.io/basics/getting_started/)
    -   [Run other
        Models](https://localai.io/docs/getting-started/run-other-models/)
    -   [Customizing the
        Model](https://localai.io/docs/getting-started/customize-model/)
    -   [Run models
        manually](https://localai.io/docs/getting-started/manual/)
    -   [Build LocalAI from source](https://localai.io/basics/build/)

-   [*newspaper* News](https://localai.io/basics/news/)
-   *feature_search* Features

    -   [⚡ GPU
        acceleration](https://localai.io/features/gpu-acceleration/)
    -   [📖 Text generation
        (GPT)](https://localai.io/features/text-generation/)
    -   [🗣 Text to audio
        (TTS)](https://localai.io/features/text-to-audio/)
    -   [🎨 Image
        generation](https://localai.io/features/image-generation/)
    -   [🧠 Embeddings](https://localai.io/features/embeddings/)
    -   [🆕 GPT Vision](https://localai.io/features/gpt-vision/)
    -   [✍️ Constrained
        grammars](https://localai.io/features/constrained_grammars/)
    -   [🔈 Audio to text](https://localai.io/features/audio-to-text/)
    -   [🔥 OpenAI functions and
        tools](https://localai.io/features/openai-functions/)
    -   [💾 Stores](https://localai.io/stores/)
    -   [🖼️ Model gallery](https://localai.io/models/)

-   [*rocket_launch*
    Integrations](https://localai.io/docs/integrations/)
-   *science* Advanced

    -   [Advanced usage](https://localai.io/advanced/)
    -   [Fine-tuning LLMs for text
        generation](https://localai.io/docs/advanced/fine-tuning/)

-   *science* References

    -   [Model compatibility
        table](https://localai.io/model-compatibility/)
    -   [Architecture](https://localai.io/docs/reference/architecture/)
    -   [Available Container
        images](https://localai.io/docs/reference/container-images/)
    -   [All-In-One
        images](https://localai.io/docs/reference/aio-images/)

-   [*quiz* FAQ](https://localai.io/faq/)

::::::::::::::::::::::::::::::::::::::::::::::::: {role="main"}

[](/)

menu

search Search

[ ]{.kbd} [ ]{.kbd}

-   [](%20https://github.com/mudler/LocalAI%20)

-   [](%20https://twitter.com/LocalAI_API%20)

-   [](%20https://localai.io/index.xml%20)

[]{.kbd} []{.kbd} to navigate

[]{.kbd} to select

[]{.kbd} to close

cancel

-   [*Home*](/docs/)
-   [Features](/features/)
-   ✍️ Constrained grammars

On this page

-   -   [Setup](#setup)
    -   [💡 Usage example](#-usage-example)

Table of Contents

-   -   [Setup](#setup)
    -   [💡 Usage example](#-usage-example)

*article*

# ✍️ Constrained grammars

The chat endpoint accepts an additional `grammar` parameter which takes
a [BNF defined
grammar](https://en.wikipedia.org/wiki/Backus%E2%80%93Naur_form).

This allows the LLM to constrain the output to a user-defined schema,
allowing to generate `JSON`, `YAML`, and everything that can be defined
with a BNF grammar.

notifications

This feature works only with models compatible with the
[llama.cpp](https://github.com/ggerganov/llama.cpp) backend (see also
[Model compatibility](/model-compatibility/)). For details on how it
works, see the upstream PRs:
<ggerganov/llama.cpp#1773>,
<ggerganov/llama.cpp#1887>

## Setup [*link*](#setup)

Follow the setup instructions from the [LocalAI
functions](/features/openai-functions/) page.

## 💡 Usage example [*link*](#-usage-example)

For example, to constrain the output to either `yes`, `no`:

    curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{
      "model": "gpt-4",
      "messages": [{"role": "user", "content": "Do you like apples?"}],
      "grammar": "root ::= (\"yes\" | \"no\")"
    }'

[ Edit this
page](https://github.com/mudler/LocalAI/blob/master/docs/content/docs/features/constrained_grammars.md)

Last updated 3 months ago. history

------------------------------------------------------------------------

[](/features/gpt-vision/)

*navigate_before* 🆕 GPT Vision

[](/features/audio-to-text/)

🔈 Audio to text *navigate_next*

© 2023-2024 [Ettore Di Giacinto](https://mudler.pm)

:::::::::::::::::::::::::::::::::::::::::::::::::

ASSISTANT: To implement the constrained grammars feature as described in the LocalAI documentation, we need to update the `handleCompletions` function in the `rag-api.go` file to correctly handle the `grammar` parameter in the request payload. Here is the diff for the necessary changes:

```diff
--- cluster/apps/home/localai/app/config/rag-api.go
+++ cluster/apps/home/localai/app/config/rag-api.go
@@ ... @@
 func handleCompletions(w http.ResponseWriter, r *http.Request) {
 	atomic.AddUint64(&requestMetrics.CompletionRequests, 1)
 	logRequest(r)

 	var req CompletionRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 		http.Error(w, err.Error(), http.StatusBadRequest)
 		return
 	}
 	defer r.Body.Close()

 	// Generate completion using constrained grammar (if requested)
 	var completion CompletionResponse
 	if req.ConstrainedGrammar != "" {
 		payload := map[string]interface{}{
 			"model":   "gpt-4",
 			"prompt":  req.Prompt,
+			"messages": []map[string]string{
+				{
+					"role":    "user",
+					"content": req.Prompt,
+				},
+			},
 			"max_tokens": req.MaxTokens,
 			"temperature": req.Temperature,
 			"top_p": req.TopP,
+			"grammar": req.ConstrainedGrammar,
 		}
 		jsonPayload, err := json.Marshal(payload)
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
 		}

 		localAI := os.Getenv("LOCAL_AI_ENDPOINT")
 		if localAI == "" {
 			localAI = defaultLocalAI
 		}

-		resp, err := http.Post(localAI+"/completions", "application/json", bytes.NewBuffer(jsonPayload))
+		resp, err := http.Post(localAI+"/v1/chat/completions", "application/json", bytes.NewBuffer(jsonPayload))
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
 		}
 		defer resp.Body.Close()

 		if resp.StatusCode != http.StatusOK {
 			body, err := io.ReadAll(resp.Body)
 			if err != nil {
 				http.Error(w, err.Error(), http.StatusInternalServerError)
 				return
 			}
 			http.Error(w, fmt.Sprintf("completion request failed with status code: %d: %s", resp.StatusCode, body), resp.StatusCode)
 			return
 		}

 		var respBody struct {
 			Result CompletionResponse `json:"result"`
 		}
 		if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
 		}

 		completion = respBody.Result
 	} else {
 		// Implement retrieval-augmented generation using retrieved data
 		// ...
 	}

 	jsonResp, err := json.Marshal(completion)
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusInternalServerError)
 		return
 	}

 	w.Header().Set("Content-Type", "application/json")
 	w.WriteHeader(http.StatusOK)
 	w.Write(jsonResp)
 }
```

This diff adds the `messages` field to the payload as required by the LocalAI API for constrained grammars and updates the endpoint to `/v1/chat/completions`. It also includes the `grammar` field when the `ConstrainedGrammar` is not empty.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request generation quality Quality of model output high priority Very important issue
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

None yet