Skip to content

Commit

Permalink
Exposing grammar as a request parameter in completion/chat with go-si…
Browse files Browse the repository at this point in the history
…de grammar validation
  • Loading branch information
richardanaya committed May 19, 2024
1 parent 105186a commit 24adac4
Show file tree
Hide file tree
Showing 6 changed files with 792 additions and 0 deletions.
6 changes: 6 additions & 0 deletions api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type GenerateRequest struct {
// Format specifies the format to return a response in.
Format string `json:"format"`

// Grammar specifies a GBNF grammar string that the generated output must conform to.
Grammar string `json:"grammar"`

// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
Expand Down Expand Up @@ -94,6 +97,9 @@ type ChatRequest struct {
// Format is the format to return the response in (e.g. "json").
Format string `json:"format"`

// Grammar specifies a GBNF grammar string that the generated output must conform to.
Grammar string `json:"grammar"`

// KeepAlive controls how long the model will stay loaded into memory
// following the request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
Expand Down
16 changes: 16 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
Advanced parameters (optional):

- `format`: the format to return a response in. Currently the only accepted value is `json`
- `grammar`: the [GBNF grammar](https://github.com/ggerganov/llama.cpp/tree/master/grammars) to constrain generated output to
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system message to use (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
Expand Down Expand Up @@ -162,6 +163,21 @@ curl http://localhost:11434/api/generate -d '{
}'
```

#### Request (GRAMMAR mode)

> When `grammar` is set to a [GBNF grammar](https://github.com/ggerganov/llama.cpp/tree/master/grammars), the generated output is constrained to match that grammar. This method does not rely on the prompt describing how the output should be formatted.
##### Request

```shell
curl http://localhost:11434/api/generate -d '{
"model": "llama3",
"prompt": "Are llamas amazing?",
"grammar": "root ::= \"yes\" | \"no\"",
"stream": false
}'
```

##### Response

```json
Expand Down
324 changes: 324 additions & 0 deletions llm/grammar.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
package llm

import (
"bufio"
"fmt"
"strings"
)

// removeComments strips GBNF comments from input and returns the remaining
// lines, each trimmed of surrounding whitespace and terminated by "\n".
//
// A comment starts at a '#' and runs to the end of the line. A '#' that sits
// inside a quoted literal ("...") or a character class ([...]) is grammar
// content, not a comment, and is preserved. Lines that are empty once the
// comment and surrounding whitespace are removed are dropped.
func removeComments(input string) string {
	var output strings.Builder
	scanner := bufio.NewScanner(strings.NewReader(input))
	// Let a single line be as long as the whole input. With the default
	// 64 KiB token limit the scan would stop at the first over-long line and
	// silently drop the rest of the grammar.
	scanner.Buffer(nil, len(input)+1)

	for scanner.Scan() {
		line := scanner.Text()
		if idx := gbnfCommentStart(line); idx != -1 {
			line = line[:idx]
		}
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		output.WriteString(line)
		output.WriteByte('\n')
	}
	// The source is an in-memory string and the buffer limit exceeds its
	// length, so the scanner has no I/O or token-size error to report here.

	return output.String()
}

// gbnfCommentStart returns the byte index of the '#' that begins a comment in
// line, or -1 when the line has no comment. It tracks quoted literals and
// character classes (honoring backslash escapes inside them) so that a '#'
// appearing within either is not mistaken for a comment.
func gbnfCommentStart(line string) int {
	inLiteral, inClass := false, false
	for i := 0; i < len(line); i++ {
		switch c := line[i]; {
		case c == '\\' && (inLiteral || inClass):
			i++ // skip the escaped byte so \" and \] do not close anything
		case inLiteral:
			if c == '"' {
				inLiteral = false
			}
		case inClass:
			if c == ']' {
				inClass = false
			}
		case c == '"':
			inLiteral = true
		case c == '[':
			inClass = true
		case c == '#':
			return i
		}
	}
	return -1
}

// breakIntoArrayOfRules splits comment-free grammar text into a map from rule
// name to rule body.
//
// A new rule begins on every line that contains "::="; lines without it are
// continuations of the current rule. The lines of a rule are joined with
// single spaces, the text is split on "::=", and both the name and the body
// are trimmed of surrounding whitespace. If the same name is defined more
// than once, the last definition wins.
//
// An error is returned when a rule does not contain exactly one "::="
// separator (for example text that precedes the first rule, or a "::=" that
// appears a second time within the rule, including inside a literal), or
// when the input cannot be scanned.
func breakIntoArrayOfRules(input string) (map[string]string, error) {
	scanner := bufio.NewScanner(strings.NewReader(input))
	// Let a single line be as long as the whole input. With the default
	// 64 KiB token limit an over-long line would abort the scan and the
	// remaining rules would be lost.
	scanner.Buffer(nil, len(input)+1)

	var rules []string
	var currentRule strings.Builder
	for scanner.Scan() {
		line := scanner.Text()

		// A line containing the separator starts a new rule; flush the
		// one collected so far.
		if strings.Contains(line, "::=") && currentRule.Len() > 0 {
			rules = append(rules, currentRule.String())
			currentRule.Reset()
		}

		// Continuation lines are joined to the rule with a single space.
		currentRule.WriteString(line)
		currentRule.WriteByte(' ')
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading grammar: %w", err)
	}
	if currentRule.Len() > 0 {
		rules = append(rules, currentRule.String())
	}

	rulesMap := make(map[string]string, len(rules))
	for _, rule := range rules {
		parts := strings.Split(rule, "::=")
		if len(parts) != 2 {
			return nil, fmt.Errorf("invalid rule did not contain exactly one ::= separator: %s", rule)
		}
		rulesMap[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
	}

	return rulesMap, nil
}

// TokenType identifies the lexical category of a Token found in the body of
// a grammar rule.
type TokenType int

// Token categories. The values are assigned sequentially with iota, starting
// at zero with TokenUnknown.
const (
// TokenUnknown is the zero value of TokenType.
TokenUnknown TokenType = iota
// TokenSpace names whitespace; parseRule skips spaces rather than emitting
// a token for them.
TokenSpace
// TokenPipe is the alternation operator '|'.
TokenPipe
// TokenLParen is an opening parenthesis '('.
TokenLParen
// TokenRParen is a closing parenthesis ')'.
TokenRParen
// TokenLBracket names an opening bracket '['; parseRule consumes a whole
// bracketed span as a single TokenCharacterClass instead of emitting it.
TokenLBracket
// TokenRBracket is a closing bracket ']' found outside a character class.
TokenRBracket
// TokenAsterisk is the repetition operator '*'.
TokenAsterisk
// TokenPlus is the repetition operator '+'.
TokenPlus
// TokenQuestion is the optional operator '?'.
TokenQuestion
// TokenTerminal is a double-quoted literal, quotes included.
TokenTerminal
// TokenCharacterClass is a bracketed character class, brackets included.
TokenCharacterClass
// TokenNonTerminal is a reference to another rule by name.
TokenNonTerminal
)

// Token is a single lexical element of a rule body, as produced by parseRule.
type Token struct {
// Type is the lexical category of the token.
Type TokenType
// Value is the text of the token as it appeared in the rule.
Value string
}

// isDigit reports whether c is one of the ASCII decimal digits '0' through '9'.
func isDigit(c rune) bool {
	return '0' <= c && c <= '9'
}

// isValidIdCharacter reports whether c may appear in a rule name: an ASCII
// letter, an ASCII digit, a hyphen, or an underscore.
func isValidIdCharacter(c rune) bool {
	switch {
	case c == '-', c == '_':
		return true
	case 'A' <= c && c <= 'Z':
		return true
	case 'a' <= c && c <= 'z':
		return true
	case '0' <= c && c <= '9':
		return true
	default:
		return false
	}
}

// isValidRuleName checks that name is usable as a rule name. A valid name is
// non-empty, consists only of characters accepted by isValidIdCharacter
// (A-Z, a-z, 0-9, '-', '_'), and does not begin with a digit.
//
// It returns (true, nil) for a valid name, and (false, err) with an error
// describing the first problem found otherwise.
func isValidRuleName(name string) (bool, error) {
	if name == "" {
		return false, fmt.Errorf("empty rule name")
	}

	for pos, ch := range name {
		switch {
		case pos == 0 && isDigit(ch):
			// Only the leading character is barred from being a digit.
			return false, fmt.Errorf("rule name cannot start with a number '%s'", name)
		case !isValidIdCharacter(ch):
			return false, fmt.Errorf("invalid character '%c' in rule name '%s'", ch, name)
		}
	}

	return true, nil
}

// parseRule tokenizes the body of a single grammar rule (the text to the
// right of "::=") into a slice of Tokens.
//
// Spaces and tabs separate tokens and are not emitted. The operators
// | ( ) ] * + ? each become a one-character token. A double-quoted literal
// becomes one TokenTerminal and a bracketed character class becomes one
// TokenCharacterClass; both keep their delimiters in Value, and a backslash
// inside either escapes the byte that follows it. Any other run of
// characters is taken as a rule reference and must satisfy isValidRuleName.
//
// An error is returned for an unterminated literal, an unclosed character
// class, or an invalid rule name. An empty or all-whitespace rule yields a
// nil slice and a nil error.
func parseRule(rule string) ([]Token, error) {
	// Characters that end a rule name; each has its own case below.
	const delimiters = " \t|()[]*+?"

	var tokens []Token
	rule = strings.TrimSpace(rule)
	n := len(rule)

	for i := 0; i < n; {
		switch c := rule[i]; c {
		case ' ', '\t':
			// Whitespace only separates tokens.
			i++
		case '|':
			tokens = append(tokens, Token{Type: TokenPipe, Value: "|"})
			i++
		case '(':
			tokens = append(tokens, Token{Type: TokenLParen, Value: "("})
			i++
		case ')':
			tokens = append(tokens, Token{Type: TokenRParen, Value: ")"})
			i++
		case '[':
			end, closed := gbnfScanDelimited(rule, i, ']')
			if !closed {
				return nil, fmt.Errorf("unclosed character class")
			}
			tokens = append(tokens, Token{Type: TokenCharacterClass, Value: rule[i:end]})
			i = end
		case ']':
			// A ']' seen here has no opening '[': a well-formed class is
			// consumed whole by the '[' case. validateRule reports it.
			tokens = append(tokens, Token{Type: TokenRBracket, Value: "]"})
			i++
		case '*':
			tokens = append(tokens, Token{Type: TokenAsterisk, Value: "*"})
			i++
		case '+':
			tokens = append(tokens, Token{Type: TokenPlus, Value: "+"})
			i++
		case '?':
			tokens = append(tokens, Token{Type: TokenQuestion, Value: "?"})
			i++
		case '"':
			end, closed := gbnfScanDelimited(rule, i, '"')
			if !closed {
				return nil, fmt.Errorf("unclosed string literal")
			}
			tokens = append(tokens, Token{Type: TokenTerminal, Value: rule[i:end]})
			i = end
		default:
			// Rule reference: runs up to the next delimiter.
			start := i
			for i < n && !strings.ContainsRune(delimiters, rune(rule[i])) {
				i++
			}
			if isValid, err := isValidRuleName(rule[start:i]); !isValid {
				return nil, err
			}
			tokens = append(tokens, Token{Type: TokenNonTerminal, Value: rule[start:i]})
		}
	}

	return tokens, nil
}

// gbnfScanDelimited scans s from the opening delimiter at index start to the
// matching closer, treating a backslash as escaping the byte after it. It
// returns the index just past the closer and true, or len(s) and false when
// the closer is never found.
func gbnfScanDelimited(s string, start int, closer byte) (int, bool) {
	for i := start + 1; i < len(s); i++ {
		switch s[i] {
		case '\\':
			i++ // skip the escaped byte, e.g. \" or \] or \\
		case closer:
			return i + 1, true
		}
	}
	return len(s), false
}

// validateRule checks the token sequence of one rule for structural errors.
//
// It verifies that the rule is not empty, that parentheses and brackets are
// balanced and properly nested, that each of the operators * + ? directly
// follows a terminal, non-terminal, character class, ')' or ']', that '|' is
// neither the first nor the last token, and that every non-terminal is a key
// of definedRules with a true value. The first violation found is returned
// as an error; nil means the rule passed every check.
func validateRule(tokens []Token, definedRules map[string]bool) error {
	if len(tokens) == 0 {
		return fmt.Errorf("empty rule")
	}

	last := len(tokens) - 1
	// Currently open '(' and '[' tokens, innermost last.
	var open []TokenType

	for idx := range tokens {
		tok := tokens[idx]
		switch tok.Type {
		case TokenLParen, TokenLBracket:
			open = append(open, tok.Type)

		case TokenRParen:
			top := len(open) - 1
			if top < 0 || open[top] != TokenLParen {
				return fmt.Errorf("unmatched closing parenthesis")
			}
			open = open[:top]

		case TokenRBracket:
			top := len(open) - 1
			if top < 0 || open[top] != TokenLBracket {
				return fmt.Errorf("unmatched closing bracket")
			}
			open = open[:top]

		case TokenAsterisk, TokenPlus, TokenQuestion:
			// A repetition operator needs something before it to apply to.
			if idx == 0 {
				return fmt.Errorf("unexpected %s at the beginning of the rule", tok.Value)
			}
			prev := tokens[idx-1]
			switch prev.Type {
			case TokenTerminal, TokenNonTerminal, TokenCharacterClass, TokenRParen, TokenRBracket:
				// Acceptable operand.
			default:
				return fmt.Errorf("unexpected %s following %s", tok.Value, prev.Value)
			}

		case TokenPipe:
			// Only the two ends of the whole rule are checked here.
			if idx == 0 || idx == last {
				return fmt.Errorf("unexpected pipe '|' at the beginning or end of the rule")
			}

		case TokenNonTerminal:
			if !definedRules[tok.Value] {
				return fmt.Errorf("undefined rule: %s", tok.Value)
			}
		}
	}

	if len(open) != 0 {
		return fmt.Errorf("unclosed parentheses or brackets")
	}

	return nil
}

// parseGrammar converts GBNF grammar text into a map from rule name to the
// tokens of that rule's body.
//
// Comments are removed with removeComments, the text is split into rules
// with breakIntoArrayOfRules, each rule name is checked with isValidRuleName,
// and each body is tokenized with parseRule. The first failure is returned
// as an error that names the offending rule and wraps the underlying error.
// Rules are visited in map order, so when several rules are invalid the one
// reported may vary between calls.
func parseGrammar(grammar string) (map[string]([]Token), error) {
	rules, err := breakIntoArrayOfRules(removeComments(grammar))
	if err != nil {
		return nil, err
	}

	ruleTokens := make(map[string]([]Token), len(rules))

	for name, body := range rules {
		// The name on the left of "::=" must follow the same rules as a
		// reference to it, otherwise the rule could never be referenced.
		if ok, err := isValidRuleName(name); !ok {
			return nil, fmt.Errorf("invalid rule name \"%s\": %w", name, err)
		}

		tokens, err := parseRule(body)
		if err != nil {
			return nil, fmt.Errorf("error parsing rule \"%s\": %w", name, err)
		}
		ruleTokens[name] = tokens
	}

	return ruleTokens, nil
}

// ValidateGrammar reports whether grammar is a structurally valid GBNF
// grammar, returning nil when it is and a descriptive error otherwise.
//
// GBNF is essentially a list of rules, so validation is kept simple: the
// grammar is parsed into per-rule token lists with parseGrammar, and each
// rule is then checked on its own with validateRule against the set of
// defined rule names. Errors name the rule in which the problem was found.
//
// A grammar that defines at least one rule must define a rule named "root",
// which is the entry point of a GBNF grammar. A grammar that defines no
// rules at all (for example an empty string) is accepted.
func ValidateGrammar(grammar string) error {
	ruleTokens, err := parseGrammar(grammar)
	if err != nil {
		return err
	}

	// Nothing was defined, so there is nothing to validate.
	if len(ruleTokens) == 0 {
		return nil
	}

	definedRules := make(map[string]bool, len(ruleTokens))
	for name := range ruleTokens {
		definedRules[name] = true
	}

	if !definedRules["root"] {
		return fmt.Errorf("grammar does not define a \"root\" rule")
	}

	for name, tokens := range ruleTokens {
		if err := validateRule(tokens, definedRules); err != nil {
			return fmt.Errorf("error in rule \"%s\": %w", name, err)
		}
	}
	return nil
}
Loading

0 comments on commit 24adac4

Please sign in to comment.