Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

limit max nesting depth #206

Merged
merged 4 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
66 changes: 37 additions & 29 deletions internal/json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@
// if some slice of bytes is a valid beginning of a json string.
package json

import "fmt"
import (
"fmt"
)

type (
context int
scanStatus int
)

const (
contextKey context = iota
contextObj
contextArr
parseObjectKey = iota // parsing object key (before colon)
parseObjectValue // parsing object value (after colon)
parseArrayValue // parsing array value

scanContinue scanStatus = iota // uninteresting byte
scanBeginLiteral // end implied by next result != scanContinue
Expand All @@ -56,15 +57,19 @@ const (
scanSkipSpace // space byte; can skip; known to be last "continue" result
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
scanError // hit an error, scanner.err.

// maxNestingDepth limits how deeply arrays and objects may be nested, to prevent stack overflow.
// Imposing such a limit is permitted by https://tools.ietf.org/html/rfc7159#section-9
maxNestingDepth = 10000
)

type (
scanner struct {
step func(*scanner, byte) scanStatus
contexts []context
endTop bool
err error
index int
step func(*scanner, byte) scanStatus
parseState []int
endTop bool
err error
index int
}
)

Expand Down Expand Up @@ -98,7 +103,7 @@ func isSpace(c byte) bool {

func (s *scanner) reset() {
s.step = stateBeginValue
s.contexts = s.contexts[0:0]
s.parseState = s.parseState[0:0]
s.err = nil
}

Expand All @@ -121,16 +126,21 @@ func (s *scanner) eof() scanStatus {
return scanError
}

// pushContext pushes a new parse state p onto the parse stack.
func (s *scanner) pushParseState(p context) {
s.contexts = append(s.contexts, p)
// pushParseState pushes newParseState onto the parse stack.
// It returns an error status if maxNestingDepth is exceeded; otherwise it returns successState.
func (s *scanner) pushParseState(c byte, newParseState int, successState scanStatus) scanStatus {
s.parseState = append(s.parseState, newParseState)
if len(s.parseState) <= maxNestingDepth {
return successState
}
return s.error(c, "exceeded max depth")
}

// popParseState pops a parse state (already obtained) off the stack
// and updates s.step accordingly.
func (s *scanner) popParseState() {
n := len(s.contexts) - 1
s.contexts = s.contexts[0:n]
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
if n == 0 {
s.step = stateEndTop
s.endTop = true
Expand Down Expand Up @@ -158,12 +168,10 @@ func stateBeginValue(s *scanner, c byte) scanStatus {
switch c {
case '{':
s.step = stateBeginStringOrEmpty
s.pushParseState(contextKey)
return scanBeginObject
return s.pushParseState(c, parseObjectKey, scanBeginObject)
case '[':
s.step = stateBeginValueOrEmpty
s.pushParseState(contextArr)
return scanBeginArray
return s.pushParseState(c, parseArrayValue, scanBeginArray)
case '"':
s.step = stateInString
return scanBeginLiteral
Expand Down Expand Up @@ -196,8 +204,8 @@ func stateBeginStringOrEmpty(s *scanner, c byte) scanStatus {
return scanSkipSpace
}
if c == '}' {
n := len(s.contexts)
s.contexts[n-1] = contextObj
n := len(s.parseState)
s.parseState[n-1] = parseObjectValue
return stateEndValue(s, c)
}
return stateBeginString(s, c)
Expand All @@ -218,7 +226,7 @@ func stateBeginString(s *scanner, c byte) scanStatus {
// stateEndValue is the state after completing a value,
// such as after reading `{}` or `true` or `["x"`.
func stateEndValue(s *scanner, c byte) scanStatus {
n := len(s.contexts)
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
Expand All @@ -229,18 +237,18 @@ func stateEndValue(s *scanner, c byte) scanStatus {
s.step = stateEndValue
return scanSkipSpace
}
ps := s.contexts[n-1]
ps := s.parseState[n-1]
switch ps {
case contextKey:
case parseObjectKey:
if c == ':' {
s.contexts[n-1] = contextObj
s.parseState[n-1] = parseObjectValue
s.step = stateBeginValue
return scanObjectKey
}
return s.error(c, "after object key")
case contextObj:
case parseObjectValue:
if c == ',' {
s.contexts[n-1] = contextKey
s.parseState[n-1] = parseObjectKey
s.step = stateBeginString
return scanObjectValue
}
Expand All @@ -249,7 +257,7 @@ func stateEndValue(s *scanner, c byte) scanStatus {
return scanEndObject
}
return s.error(c, "after object key:value pair")
case contextArr:
case parseArrayValue:
if c == ',' {
s.step = stateBeginValue
return scanArrayValue
Expand Down
88 changes: 68 additions & 20 deletions internal/json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,30 @@

package json

import "testing"

var scanTests = []struct {
data string
length int
ok bool
}{
{`foo`, 2, false},
{`}{`, 1, false},
{`{]`, 2, false},
{`{}`, 2, true},
{`{"foo":"bar"}`, 13, true},
{`{"foo":"21\t\u0009 \u1234","bar":{"baz":["qux"]}`, 48, false},
{`{"foo":"bar","bar":{"baz":["qux"]}}`, 35, true},
{`{"foo":-1,"bar":{"baz":[true, false, null, 100, 0.123]}}`, 56, true},
{`{"foo":-1,"bar":{"baz":[tru]}}`, 28, false},
{`{"foo":-1,"bar":{"baz":[nul]}}`, 28, false},
{`{"foo":-1,"bar":{"baz":[314e+1]}}`, 33, true},
}
import (
"strings"
"testing"
)

func TestScan(t *testing.T) {
for _, st := range scanTests {
tCases := []struct {
data string
length int
ok bool
}{
{`foo`, 2, false},
{`}{`, 1, false},
{`{]`, 2, false},
{`{}`, 2, true},
{`{"foo":"bar"}`, 13, true},
{`{"foo":"21\t\u0009 \u1234","bar":{"baz":["qux"]}`, 48, false},
{`{"foo":"bar","bar":{"baz":["qux"]}}`, 35, true},
{`{"foo":-1,"bar":{"baz":[true, false, null, 100, 0.123]}}`, 56, true},
{`{"foo":-1,"bar":{"baz":[tru]}}`, 28, false},
{`{"foo":-1,"bar":{"baz":[nul]}}`, 28, false},
{`{"foo":-1,"bar":{"baz":[314e+1]}}`, 33, true},
}
for _, st := range tCases {
scanned, err := Scan([]byte(st.data))
if scanned != st.length {
t.Errorf("Scan length error: expected: %d; got: %d; input: %s",
Expand All @@ -41,3 +43,49 @@ func TestScan(t *testing.T) {
}
}
}

// TestScannerMaxDepth feeds Scan documents whose arrays or objects are nested
// just below and just past the 10000-level limit and checks that only the
// latter are rejected with an "exceeded max depth" error.
func TestScannerMaxDepth(t *testing.T) {
	const depth = 10000

	// nest wraps core in the given number of opener/closer pairs and places
	// the result as the value of key "a" in an enclosing object, so the
	// enclosing object contributes one extra '{' on top of levels.
	nest := func(opener, core, closer string, levels int) string {
		return `{"a":` + strings.Repeat(opener, levels) + core + strings.Repeat(closer, levels) + `}`
	}

	cases := []struct {
		name        string
		data        string
		errMaxDepth bool
	}{
		{"ArrayUnderMaxNestingDepth", nest(`[`, ``, `]`, depth-1), false},
		{"ArrayOverMaxNestingDepth", nest(`[`, ``, `]`, depth), true},
		{"ObjectUnderMaxNestingDepth", nest(`{"a":`, `0`, `}`, depth-1), false},
		{"ObjectOverMaxNestingDepth", nest(`{"a":`, `0`, `}`, depth), true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := Scan([]byte(tc.data))
			switch {
			case !tc.errMaxDepth:
				// Input within the limit must scan without any error.
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
			case err == nil:
				t.Errorf("expected error containing 'exceeded max depth', got none")
			case !strings.Contains(err.Error(), "exceeded max depth"):
				t.Errorf("expected error containing 'exceeded max depth', got: %v", err)
			}
		})
	}
}