-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Exposing grammar as a request parameter in completion/chat with go-side grammar validation
- Loading branch information
1 parent
105186a
commit 1dd5575
Showing
6 changed files
with
770 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,318 @@ | ||
package llm | ||
|
||
import (
	"bufio"
	"fmt"
	"sort"
	"strings"
)
|
||
// removeComments strips GBNF comments from input and returns the remaining
// text with one trimmed, non-empty line per output line (each terminated by
// "\n").
//
// A comment starts at a '#' and runs to the end of the line. A '#' that sits
// inside a double-quoted terminal ("...") or a character class ([...]) is part
// of the grammar and is left untouched. Lines that are empty after comment
// removal and trimming are dropped.
func removeComments(input string) string {
	var output strings.Builder
	scanner := bufio.NewScanner(strings.NewReader(input))
	// The default Scanner limit is 64 KiB per line; allow a line to be as
	// long as the whole input so long rules are not silently dropped.
	scanner.Buffer(make([]byte, 0, 4096), len(input)+1)

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())

		// Cut the line at the first '#' that is outside a literal.
		if idx := grammarCommentIndex(line); idx != -1 {
			line = strings.TrimSpace(line[:idx])
		}

		if line != "" {
			output.WriteString(line)
			output.WriteString("\n")
		}
	}
	// scanner.Err() is not checked: the reader is an in-memory string and the
	// token limit above covers the whole input, so the scan cannot fail.

	return output.String()
}

// grammarCommentIndex returns the byte index of the '#' that starts a comment
// in line, or -1 if the line has no comment. Quoted terminals and character
// classes are skipped, honouring backslash escapes inside them.
func grammarCommentIndex(line string) int {
	inString, inClass := false, false
	for i := 0; i < len(line); i++ {
		c := line[i]
		switch {
		case (inString || inClass) && c == '\\':
			i++ // skip the escaped byte
		case inString:
			inString = c != '"'
		case inClass:
			inClass = c != ']'
		case c == '"':
			inString = true
		case c == '[':
			inClass = true
		case c == '#':
			return i
		}
	}
	return -1
}
|
||
// breakIntoArrayOfRules splits grammar text into a map from rule name to rule
// body.
//
// A line containing "::=" starts a new rule; following lines without "::="
// are continuation lines and are joined to the rule with single spaces. The
// name is the trimmed text before the first "::=" and the body is the trimmed
// text after it, so a "::=" occurring later on the same line (for example
// inside a quoted terminal) stays in the body.
//
// Text that appears before the first "::=" is stored under its own trimmed
// text with an empty body instead of causing a panic, so that later validation
// can report it. Whitespace-only chunks are ignored. If a name is defined more
// than once, the last definition wins.
func breakIntoArrayOfRules(input string) map[string]string {
	scanner := bufio.NewScanner(strings.NewReader(input))
	// Lift the default 64 KiB per-line limit so long rules are not lost.
	scanner.Buffer(make([]byte, 0, 4096), len(input)+1)

	var rules []string
	var currentRule strings.Builder
	for scanner.Scan() {
		line := scanner.Text()

		// A new "::=" closes the rule accumulated so far.
		if strings.Contains(line, "::=") && currentRule.Len() > 0 {
			rules = append(rules, currentRule.String())
			currentRule.Reset()
		}

		// Lines are joined with a space so a rule becomes a single line.
		currentRule.WriteString(line)
		currentRule.WriteString(" ")
	}
	if currentRule.Len() > 0 {
		rules = append(rules, currentRule.String())
	}

	rulesMap := make(map[string]string, len(rules))
	for _, rule := range rules {
		if strings.TrimSpace(rule) == "" {
			continue
		}
		// Cut never panics: without "::=" the whole text becomes the name
		// and the body is empty.
		name, body, _ := strings.Cut(rule, "::=")
		rulesMap[strings.TrimSpace(name)] = strings.TrimSpace(body)
	}

	return rulesMap
}
|
||
type (
	// TokenType classifies a Token produced while tokenizing a rule body.
	TokenType int

	// Token is one lexical element of a rule body: its category and the
	// exact source text it was read from.
	Token struct {
		Type  TokenType
		Value string
	}
)

// Token categories. The numeric values follow declaration order, starting at
// zero with TokenUnknown.
const (
	// TokenUnknown is the zero value of TokenType.
	TokenUnknown TokenType = iota
	// TokenSpace is declared for whitespace.
	TokenSpace
	// TokenPipe is the alternation operator '|'.
	TokenPipe
	// TokenLParen is an opening parenthesis '('.
	TokenLParen
	// TokenRParen is a closing parenthesis ')'.
	TokenRParen
	// TokenLBracket is an opening bracket '['.
	TokenLBracket
	// TokenRBracket is a closing bracket ']'.
	TokenRBracket
	// TokenAsterisk is the repetition operator '*'.
	TokenAsterisk
	// TokenPlus is the repetition operator '+'.
	TokenPlus
	// TokenQuestion is the optional operator '?'.
	TokenQuestion
	// TokenTerminal is a double-quoted literal, quotes included.
	TokenTerminal
	// TokenCharacterClass is a bracketed character class, brackets included.
	TokenCharacterClass
	// TokenNonTerminal is a reference to another rule by name.
	TokenNonTerminal
)
|
||
// isDigit reports whether c is one of the ASCII digits '0' through '9'.
func isDigit(c rune) bool {
	if c < '0' || c > '9' {
		return false
	}
	return true
}
|
||
// isValidIdCharacter reports whether c may appear in a rule name: an ASCII
// letter, an ASCII digit, '-' or '_'.
func isValidIdCharacter(c rune) bool {
	switch {
	case c == '-', c == '_':
		return true
	case 'A' <= c && c <= 'Z':
		return true
	case 'a' <= c && c <= 'z':
		return true
	case '0' <= c && c <= '9':
		return true
	default:
		return false
	}
}
|
||
func isValidRuleName(name string) (bool, error) { | ||
if len(name) == 0 { | ||
return false, fmt.Errorf("empty rule name") | ||
} | ||
|
||
// can have A-Z, a-z, 0-9, -, _ and cannot start with number | ||
for i, c := range name { | ||
if i == 0 { | ||
if isDigit(c) { | ||
return false, fmt.Errorf("rule name cannot start with a number '%s'", name) | ||
} | ||
} | ||
if !isValidIdCharacter(c) { | ||
return false, fmt.Errorf("invalid character '%c' in rule name '%s'", c, name) | ||
} | ||
} | ||
|
||
return true, nil | ||
} | ||
|
||
func parseRule(rule string) ([]Token, error) { | ||
var tokens []Token | ||
rule = strings.TrimSpace(rule) | ||
i := 0 | ||
n := len(rule) | ||
|
||
for i < n { | ||
switch { | ||
case rule[i] == ' ': | ||
// Skip spaces | ||
i++ | ||
case rule[i] == '|': | ||
tokens = append(tokens, Token{Type: TokenPipe, Value: string(rule[i])}) | ||
i++ | ||
case rule[i] == '(': | ||
tokens = append(tokens, Token{Type: TokenLParen, Value: string(rule[i])}) | ||
i++ | ||
case rule[i] == ')': | ||
tokens = append(tokens, Token{Type: TokenRParen, Value: string(rule[i])}) | ||
i++ | ||
case rule[i] == '[': | ||
// Character class | ||
start := i | ||
i++ | ||
for i < n { | ||
if rule[i] == ']' && rule[i-1] != '\\' { | ||
i++ | ||
break | ||
} | ||
i++ | ||
} | ||
if i <= n { | ||
tokens = append(tokens, Token{Type: TokenCharacterClass, Value: rule[start:i]}) | ||
} else { | ||
return nil, fmt.Errorf("unclosed character class") | ||
} | ||
case rule[i] == ']': | ||
tokens = append(tokens, Token{Type: TokenRBracket, Value: string(rule[i])}) | ||
i++ | ||
case rule[i] == '*': | ||
tokens = append(tokens, Token{Type: TokenAsterisk, Value: string(rule[i])}) | ||
i++ | ||
case rule[i] == '+': | ||
tokens = append(tokens, Token{Type: TokenPlus, Value: string(rule[i])}) | ||
i++ | ||
case rule[i] == '?': | ||
tokens = append(tokens, Token{Type: TokenQuestion, Value: string(rule[i])}) | ||
i++ | ||
case rule[i] == '"': | ||
// Terminal sequence with escaped quotes handling | ||
start := i | ||
i++ | ||
for i < n { | ||
if rule[i] == '\\' && i+1 < n { | ||
i += 2 | ||
} else if rule[i] == '"' { | ||
i++ | ||
break | ||
} else { | ||
i++ | ||
} | ||
} | ||
tokens = append(tokens, Token{Type: TokenTerminal, Value: rule[start:i]}) | ||
default: | ||
// Non-terminal or literal character | ||
start := i | ||
for i < n && !strings.ContainsRune(" |()[]*+?", rune(rule[i])) { | ||
i++ | ||
} | ||
// Check if it is a valid rule name or character class | ||
if isValid, err := isValidRuleName(rule[start:i]); !isValid { | ||
return nil, err | ||
} | ||
tokens = append(tokens, Token{Type: TokenNonTerminal, Value: rule[start:i]}) | ||
} | ||
} | ||
|
||
return tokens, nil | ||
} | ||
|
||
func validateRule(tokens []Token, definedRules map[string]bool) error { | ||
if len(tokens) == 0 { | ||
return fmt.Errorf("empty rule") | ||
} | ||
|
||
// Stack to track the balance of parentheses and brackets | ||
var stack []TokenType | ||
|
||
for i, token := range tokens { | ||
switch token.Type { | ||
case TokenLParen: | ||
stack = append(stack, TokenLParen) | ||
case TokenRParen: | ||
if len(stack) == 0 || stack[len(stack)-1] != TokenLParen { | ||
return fmt.Errorf("unmatched closing parenthesis") | ||
} | ||
stack = stack[:len(stack)-1] | ||
case TokenLBracket: | ||
stack = append(stack, TokenLBracket) | ||
case TokenRBracket: | ||
if len(stack) == 0 || stack[len(stack)-1] != TokenLBracket { | ||
return fmt.Errorf("unmatched closing bracket") | ||
} | ||
stack = stack[:len(stack)-1] | ||
case TokenAsterisk, TokenPlus, TokenQuestion: | ||
// These tokens should follow either a terminal, non-terminal, or character class | ||
if i == 0 { | ||
return fmt.Errorf("unexpected %s at the beginning of the rule", token.Value) | ||
} | ||
prevToken := tokens[i-1] | ||
if prevToken.Type != TokenTerminal && prevToken.Type != TokenNonTerminal && prevToken.Type != TokenCharacterClass && prevToken.Type != TokenRParen && prevToken.Type != TokenRBracket { | ||
return fmt.Errorf("unexpected %s following %s", token.Value, prevToken.Value) | ||
} | ||
case TokenPipe: | ||
// Pipe must be within a sequence, not at the beginning or end | ||
if i == 0 || i == len(tokens)-1 { | ||
return fmt.Errorf("unexpected pipe '|' at the beginning or end of the rule") | ||
} | ||
case TokenNonTerminal: | ||
// Check if the non-terminal is defined | ||
if !definedRules[token.Value] { | ||
return fmt.Errorf("undefined rule: %s", token.Value) | ||
} | ||
} | ||
} | ||
|
||
if len(stack) > 0 { | ||
return fmt.Errorf("unclosed parentheses or brackets") | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func parseGrammar(grammar string) (map[string]([]Token), error) { | ||
cleanedInput := removeComments(grammar) | ||
rules := breakIntoArrayOfRules(cleanedInput) | ||
|
||
ruleTokens := make(map[string]([]Token)) | ||
|
||
for key, value := range rules { | ||
tokens, err := parseRule(value) | ||
if err != nil { | ||
return nil, fmt.Errorf("error parsing rule \"%s\": %v", key, err) | ||
} | ||
ruleTokens[key] = tokens | ||
} | ||
|
||
return ruleTokens, nil | ||
} | ||
|
||
func ValidateGrammar(grammar string) error { | ||
// Since GBNF is essentially just a list of rules, we can validate the grammar by | ||
// removing all comments, removing all white spice | ||
// and then breaking the input into an array of rules | ||
// and validating each rule one by one. This will hopefully | ||
// kee the function simple and easy to understand | ||
// while still being able to validate the grammar | ||
// with some help to tell where something is wrong | ||
|
||
ruleTokens, err := parseGrammar(grammar) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
definedRules := make(map[string]bool) | ||
for key := range ruleTokens { | ||
definedRules[key] = true | ||
} | ||
|
||
for key, value := range ruleTokens { | ||
if err := validateRule(value, definedRules); err != nil { | ||
return fmt.Errorf("error in rule \"%s\": %v", key, err) | ||
} | ||
} | ||
return nil | ||
} |
Oops, something went wrong.