Skip to content

Commit

Permalink
Prototype to add support for a two-arg pggen.arg() directive to scann…
Browse files Browse the repository at this point in the history
…er/parser
  • Loading branch information
N. Ben Cohen committed Jun 18, 2023
1 parent 55d8fc9 commit fece1fe
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 30 deletions.
52 changes: 28 additions & 24 deletions internal/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package parser

import (
"fmt"
"github.com/jschaf/pggen/internal/ast"
"github.com/jschaf/pggen/internal/scanner"
"github.com/jschaf/pggen/internal/token"
goscan "go/scanner"
gotok "go/token"
"regexp"
"strconv"
"strings"

"github.com/jschaf/pggen/internal/ast"
"github.com/jschaf/pggen/internal/scanner"
"github.com/jschaf/pggen/internal/token"
)

type parser struct {
Expand Down Expand Up @@ -219,9 +220,7 @@ func (p *parser) parseQuery() ast.Query {
p.error(p.pos, "unterminated query (no semicolon): "+string(p.src[pos:p.pos]))
return &ast.BadQuery{From: pos, To: p.pos}
}
hasPggenArg := strings.HasSuffix(p.lit, "pggen.arg(") ||
strings.HasSuffix(p.lit, "pggen.arg (")
if p.tok == token.QueryFragment && hasPggenArg {
if p.tok == token.Directive {
arg, ok := p.parsePggenArg()
if !ok {
return &ast.BadQuery{From: pos, To: p.pos}
Expand Down Expand Up @@ -329,35 +328,40 @@ func validateProtoMsgType(val string) (string, error) {

// argPos is the name and position of expression like pggen.arg('foo').
type argPos struct {
lo, hi int
name string
lo, hi int
name string
defaultValueExpression string
}

// parsePggenArg parses the name from: pggen.arg('foo') and pos for the start
// and end.
func (p *parser) parsePggenArg() (argPos, bool) {
lo := int(p.pos) + strings.LastIndex(p.lit, "pggen") - 1
p.next() // consume query fragment that contains "pggen.arg("
if p.tok != token.String {
p.error(p.pos, `expected string literal after "pggen.arg("`)
return argPos{}, false
hi := int(p.pos) + len(p.lit) - 1

var nameStart int
nameEnd := -1

nameStart = strings.Index(p.lit, "'")
if nameStart != -1 {
nameEnd = (nameStart + 1) + strings.Index(p.lit[nameStart+1:], "'")
}
if len(p.lit) < 3 || p.lit[0] != '\'' || p.lit[len(p.lit)-1] != '\'' {

if nameStart == -1 || nameEnd == -1 || nameEnd-nameStart < 3 || p.lit[nameStart] != '\'' || p.lit[nameEnd] != '\'' {
p.error(p.pos, `expected single-quoted string literal after "pggen.arg("`)
return argPos{}, false
}
name := p.lit[1 : len(p.lit)-1]
p.next() // consume string literal
if p.tok != token.QueryFragment {
p.error(p.pos, `expected query fragment after parsing pggen.arg string`)
return argPos{}, false
}
if !strings.HasPrefix(p.lit, ")") {
p.error(p.pos, `expected closing paren ")" after parsing pggen.arg string`)
return argPos{}, false
name := p.lit[nameStart+1 : nameEnd]

firstComma := strings.Index(p.lit, ",")

defaultValueExpression := ""
if firstComma != -1 {
defaultValueExpression = p.lit[firstComma+1 : len(p.lit)-1]
}
hi := int(p.pos)
return argPos{lo: lo, hi: hi, name: name}, true

p.next()
return argPos{lo: lo, hi: hi, name: name, defaultValueExpression: defaultValueExpression}, true
}

// prepareSQL replaces each pggen.arg with the $n, respecting the order that the
Expand Down
89 changes: 88 additions & 1 deletion internal/scanner/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package scanner
import (
"bytes"
"fmt"
"github.com/jschaf/pggen/internal/token"
gotok "go/token"
"unicode"
"unicode/utf8"

"github.com/jschaf/pggen/internal/token"
)

const (
Expand Down Expand Up @@ -132,6 +133,22 @@ func (s *Scanner) peek() byte {
return 0
}

// peekDirective looks ahead from most recently read character without
// advancing the scanner. If the lookahead finds patterns matching the `pggen.arg(` directive, then peekDirective returns true and otherwise false
func (s *Scanner) peekDirective() bool {
patterns := []string{"pggen.arg(", "pggen.arg ("}
for _, pattern := range patterns {
n := len(pattern)
if s.offset+n < len(s.src) {
nextString := string(s.src[s.offset : s.offset+n])
if nextString == pattern {
return true
}
}
}
return false
}

func (s *Scanner) skipWhitespace() {
for isSpace(s.ch) {
s.next()
Expand Down Expand Up @@ -265,6 +282,11 @@ func (s *Scanner) scanDoubleQuoteString() (token.Token, string) {
func (s *Scanner) scanQueryFragment() (token.Token, string) {
offs := s.offset
for s.ch > 0 {

if s.peekDirective() {
return token.QueryFragment, string(s.src[offs:s.offset])
}

switch {
case s.ch == eof:
str := string(s.src[offs:s.offset])
Expand Down Expand Up @@ -295,6 +317,65 @@ func (s *Scanner) scanQueryFragment() (token.Token, string) {
return token.QueryFragment, string(s.src[offs:s.offset])
}

// scanDirective consumes pggen.arg('one_arg') or pggen.arg('one_arg', default_value_expression)
func (s *Scanner) scanDirective() (token.Token, string) {
offs := s.offset
openParenCount := 0
for s.ch > 0 {
switch {
case s.ch == eof:
str := string(s.src[offs:s.offset])
s.error(offs, "illegal pggen.arg() expression: "+str)
return token.Illegal, str
case s.ch == '-' && s.peek() == '-':
s.scanLineComment()
continue
case s.ch == '/' && s.peek() == '*':
s.scanBlockComment()
continue
case s.ch == '\'':
s.scanSingleQuoteString()
continue
case s.ch == '"':
s.scanDoubleQuoteString()
continue
case s.ch == '$':
// A dollar sign can be part of an identifier. Consume the identifier
// here for cases like 'select 1 as foo$$$$bar'.
if isLetter(s.prevCh) || isDecimal(s.prevCh) {
for isLetter(s.ch) || isDecimal(s.ch) || s.ch == '$' {
s.next()
}
continue
} else {
s.scanDollarQuoteString()
continue
}
case s.ch == '(':
openParenCount += 1
case s.ch == ')':
openParenCount -= 1
if openParenCount == 0 {
s.next()
return token.Directive, string(s.src[offs:s.offset])
}
if openParenCount < 0 {
str := string(s.src[offs:s.offset])
s.error(offs, "illegal pggen.arg() expression: "+str)
return token.Illegal, str
}
case s.ch == ';':
str := string(s.src[offs:s.offset])
s.error(offs, "illegal pggen.arg() expression: "+str)
return token.Illegal, str
}
s.next()
}
str := string(s.src[offs:s.offset])
s.error(offs, "illegal pggen.arg() expression: "+str)
return token.Illegal, str
}

// Scan scans the next token and returns the token position, the token, and its
// literal string if applicable. The source end is indicated by token.EOF.
//
Expand All @@ -317,6 +398,12 @@ func (s *Scanner) Scan() (pos gotok.Pos, tok token.Token, lit string) {
s.skipWhitespace()
pos = s.file.Pos(s.offset)

if s.peekDirective() {
tok, lit = s.scanDirective()
s.prev = tok
return
}

switch s.ch {
case eof:
tok = token.EOF
Expand Down
18 changes: 13 additions & 5 deletions internal/scanner/scanner_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package scanner

import (
"github.com/jschaf/pggen/internal/token"
"github.com/stretchr/testify/assert"
gotok "go/token"
"testing"

"github.com/jschaf/pggen/internal/token"
"github.com/stretchr/testify/assert"
)

func newlineCount(s string) int {
Expand Down Expand Up @@ -49,9 +50,12 @@ func (ec *errorCollector) asHandler() ErrorHandler {
}
}

func frag(lit string) stringTok { return stringTok{t: token.QueryFragment, lit: lit, raw: lit} }
func str(lit string) stringTok { return stringTok{t: token.String, lit: lit, raw: lit} }
func ident(ident string) stringTok { return stringTok{t: token.QuotedIdent, lit: ident, raw: ident} }
func frag(lit string) stringTok { return stringTok{t: token.QueryFragment, lit: lit, raw: lit} }
func lineComment(lit string) stringTok { return stringTok{t: token.LineComment, lit: lit, raw: lit} }
func blockComment(lit string) stringTok { return stringTok{t: token.BlockComment, lit: lit, raw: lit} }
func str(lit string) stringTok { return stringTok{t: token.String, lit: lit, raw: lit} }
func ident(ident string) stringTok { return stringTok{t: token.QuotedIdent, lit: ident, raw: ident} }
func directive(lit string) stringTok { return stringTok{t: token.Directive, lit: lit, raw: lit} }

func TestScanner_Scan(t *testing.T) {
type testCase struct {
Expand All @@ -71,6 +75,10 @@ func TestScanner_Scan(t *testing.T) {
{"/* abc */", []stringTok{{t: token.BlockComment, lit: "/* abc */"}}, nil},
{"/* /* abc */ */", []stringTok{{t: token.BlockComment, lit: "/* /* abc */ */"}}, nil},
{"SELECT 1", []stringTok{frag("SELECT 1")}, nil},
{"SELECT pggen.arg('arg1')", []stringTok{frag("SELECT "), directive("pggen.arg('arg1')")}, nil},
{"SELECT pggen.arg('arg2', null::int)", []stringTok{frag("SELECT "), directive("pggen.arg('arg2', null::int)")}, nil},
{"SELECT pggen.arg('arg2', exists(SELECT 1 FROM bar))", []stringTok{frag("SELECT "), directive("pggen.arg('arg2', exists(SELECT 1 FROM bar))")}, nil},
{"SELECT pggen.arg('arg2', exists(SELECT '}'\n-- test comment }\n/* test comment }*/ FROM bar))", []stringTok{frag("SELECT "), directive("pggen.arg('arg2', exists(SELECT '}'\n-- test comment }\n/* test comment }*/ FROM bar))")}, nil},
{"SELECT abc$", []stringTok{frag("SELECT abc$")}, nil},
{"SELECT a$$bc", []stringTok{frag("SELECT a$$bc")}, nil},
{"SELECT a$$$bc", []stringTok{frag("SELECT a$$$bc")}, nil},
Expand Down
3 changes: 3 additions & 0 deletions internal/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
QuotedIdent // "foo_bar""baz"
QueryFragment // anything else
Semicolon // semicolon ending a query
Directive // default value supplied to two argument version of ppgen.arg()
)

func (t Token) String() string {
Expand All @@ -35,6 +36,8 @@ func (t Token) String() string {
return "QueryFragment"
case Semicolon:
return "Semicolon"
case Directive:
return "Directive"
default:
panic("unhandled token.String(): " + strconv.Itoa(int(t)))
}
Expand Down

0 comments on commit fece1fe

Please sign in to comment.