Skip to content

Commit

Permalink
adding cacheing and new test
Browse files Browse the repository at this point in the history
  • Loading branch information
richardanaya committed Jun 1, 2024
1 parent 1181b8a commit 43805ff
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ curl http://localhost:11434/api/generate -d '{

#### Request (GBNF mode)

> When `grammar` is set to a [GBNF grammar](https://github.com/ggerganov/llama.cpp/tree/master/grammars) output will be constrained to the grammar's rules. This method does not rely upon the prompt containing references to how it should output.
> When `grammar` is set to a [GBNF grammar](https://github.com/ggerganov/llama.cpp/tree/master/grammars) output will be constrained to the grammar's rules. This method does not rely upon the prompt containing references to how it should output. Before you call with this property, the environment variable `OLLAMA_GRAMMAR=true` must be.
##### Request

Expand Down
45 changes: 43 additions & 2 deletions llm/grammar.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@ import (
"bufio"
"fmt"
"strings"

"github.com/ollama/ollama/format"
)

var maxGrammarSize = 32 * format.KiloByte

// a cache that stores max 100 grammars
var grammarValidationCache = make(map[string]error)

func findIndexOfTextNotInQuotesOrCharacterSet(input string, text string) int {
quoteBalance := 0
bracketBalance := 0
Expand Down Expand Up @@ -244,6 +251,11 @@ func validateCharacterClass(charClass string) error {
func validateStringLiteral(strLiteral string) error {
validEscapeCharacters := "\\\"ntrxu"

// make sure the string literal starts and ends with a quote
if len(strLiteral) < 2 || strLiteral[0] != '"' || strLiteral[len(strLiteral)-1] != '"' {
return fmt.Errorf("string literal must start and end with a quote")
}

i := 0
for i < len(strLiteral) {
if strLiteral[i] == '\\' {
Expand Down Expand Up @@ -463,7 +475,29 @@ func parseGrammar(grammar string) (map[string]([]Token), error) {
return ruleTokens, nil
}

func addToCache(grammar string, err error) {
if len(grammarValidationCache) >= 100 {
// remove the first element
for key := range grammarValidationCache {
delete(grammarValidationCache, key)
break
}
}
grammarValidationCache[grammar] = err
}

func ValidateGrammar(grammar string) error {
// check to see if we've cached this before and if so return it
if err, ok := grammarValidationCache[grammar]; ok {
return err
}

if len(grammar) > maxGrammarSize {
err := fmt.Errorf("grammar size exceeds maximum size of %d bytes", maxGrammarSize)
addToCache(grammar, err)
return err
}

// Since GBNF is essentially just a list of rules, we can validate the grammar by
// removing all comments, removing all non-essential white space
// and then breaking the input into an array of rules
Expand All @@ -474,6 +508,7 @@ func ValidateGrammar(grammar string) error {

ruleTokens, err := parseGrammar(grammar)
if err != nil {
addToCache(grammar, err)
return err
}

Expand All @@ -484,13 +519,19 @@ func ValidateGrammar(grammar string) error {

// check that it has root rule
if _, ok := definedRules["root"]; !ok {
return fmt.Errorf("no root rule defined")
err := fmt.Errorf("no root rule defined")
addToCache(grammar, err)
return err
}

for key, value := range ruleTokens {
if err := validateRule(value, definedRules); err != nil {
return fmt.Errorf("error in rule \"%s\": %v", key, err)
err = fmt.Errorf("error in rule \"%s\": %v", key, err)
addToCache(grammar, err)
return err
}
}

addToCache(grammar, nil)
return nil
}
9 changes: 9 additions & 0 deletions llm/grammar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,15 @@ func TestNoRoot(t *testing.T) {
}
}

func TestInvalidRoot(t *testing.T) {
// this is a common typo
input := `root ::= "yes`
err := ValidateGrammar(input)
if err == nil {
t.Errorf("Expected error validating grammar, got nil")
}
}

func TestBadLlama(t *testing.T) {
// this is a common typo
input := `root :== "yes" | "no"`
Expand Down
6 changes: 6 additions & 0 deletions llm/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type llmServer struct {
gpuCount int
loadDuration time.Duration // Record how long it took the model to load
loadProgress float32
grammarEnabled bool

sem *semaphore.Weighted
}
Expand Down Expand Up @@ -306,6 +307,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
totalLayers: ggml.KV().BlockCount() + 1,
gpuCount: gpuCount,
done: make(chan error, 1),
grammarEnabled: os.Getenv("OLLAMA_GRAMMAR") == "true",
}

s.cmd.Env = os.Environ()
Expand Down Expand Up @@ -705,6 +707,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("grammar and format cannot be used together")
}

if !s.grammarEnabled {
return fmt.Errorf("grammar is specified, but OLLAMA_GRAMMAR is not enabled")
}

err := ValidateGrammar(req.Grammar)

if err != nil {
Expand Down

0 comments on commit 43805ff

Please sign in to comment.