diff --git a/conf/lex.go b/conf/lex.go index f3050331bc..5da8a73243 100644 --- a/conf/lex.go +++ b/conf/lex.go @@ -343,7 +343,10 @@ func lexKeyStart(lx *lexer) stateFn { func lexDubQuotedKey(lx *lexer) stateFn { r := lx.peek() if r == dqStringEnd { - lx.emit(itemKey) + include := lx.keyCheckKeyword(nil, nil) + if include != nil { + return include + } lx.next() return lexSkip(lx, lexKeyEnd) } else if r == eof { @@ -361,7 +364,10 @@ func lexDubQuotedKey(lx *lexer) stateFn { func lexQuotedKey(lx *lexer) stateFn { r := lx.peek() if r == sqStringEnd { - lx.emit(itemKey) + include := lx.keyCheckKeyword(nil, nil) + if include != nil { + return include + } lx.next() return lexSkip(lx, lexKeyEnd) } else if r == eof { @@ -394,7 +400,7 @@ func (lx *lexer) keyCheckKeyword(fallThrough, push stateFn) stateFn { // lexIncludeStart will consume the whitespace til the start of the value. func lexIncludeStart(lx *lexer) stateFn { r := lx.next() - if isWhitespace(r) { + if isWhitespace(r) || r == dqStringEnd || r == sqStringEnd || isKeySeparator(r) { return lexSkip(lx, lexIncludeStart) } lx.backup() @@ -441,7 +447,7 @@ func lexIncludeString(lx *lexer) stateFn { lx.backup() lx.emit(itemInclude) return lx.pop() - case r == sqStringEnd: + case r == sqStringEnd || r == dqStringEnd: lx.backup() lx.emit(itemInclude) lx.next() @@ -668,7 +674,10 @@ func lexMapKeyStart(lx *lexer) stateFn { func lexMapQuotedKey(lx *lexer) stateFn { r := lx.peek() if r == sqStringEnd { - lx.emit(itemKey) + include := lx.keyCheckKeyword(nil, lexMapValueEnd) + if include != nil { + return include + } lx.next() return lexSkip(lx, lexMapKeyEnd) } @@ -680,7 +689,10 @@ func lexMapQuotedKey(lx *lexer) stateFn { func lexMapDubQuotedKey(lx *lexer) stateFn { r := lx.peek() if r == dqStringEnd { - lx.emit(itemKey) + include := lx.keyCheckKeyword(nil, lexMapValueEnd) + if include != nil { + return include + } lx.next() return lexSkip(lx, lexMapKeyEnd) } diff --git a/conf/parse_test.go b/conf/parse_test.go index 64e4c6f580..0d5ebe2205 100644 --- a/conf/parse_test.go +++ b/conf/parse_test.go @@ -2,7 +2,9 @@ package conf import ( "fmt" + "io/ioutil" "os" + "path/filepath" "reflect" "strings" "testing" @@ -328,3 +330,97 @@ func TestIncludeVariablesWithChecks(t *testing.T) { expectKeyVal(t, m, "BOB_PASS", "$2a$11$dZM98SpGeI7dCFFGSpt.JObQcix8YHml4TBUZoge9R1uxnMIln5ly", 3, 1) expectKeyVal(t, m, "CAROL_PASS", "foo", 6, 3) } + +func TestIncludesJSONSyntax(t *testing.T) { + tmpdir, err := ioutil.TempDir("", "includes") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpdir) + + content := ` + { + "debug": true, + "include": "simple.json", + "cluster": { + "include": "cluster.json" + } + } + ` + configFile := filepath.Join(tmpdir, "multiple.json") + if err := ioutil.WriteFile(configFile, []byte(content), 0666); err != nil { + t.Fatal(err) + } + + content = ` + { + "listen": "127.0.0.1:4223" + } + ` + include := filepath.Join(tmpdir, "simple.json") + if err := ioutil.WriteFile(include, []byte(content), 0666); err != nil { + t.Fatal(err) + } + + content = ` + { + "routes": [ + "nats://nats-A:6222", + "nats://nats-C:6222", + "nats://nats-B:6222" + ] + } + ` + include = filepath.Join(tmpdir, "cluster.json") + if err := ioutil.WriteFile(include, []byte(content), 0666); err != nil { + t.Fatal(err) + } + + m, err := ParseFile(configFile) + if err != nil { + t.Fatalf("Received err: %v\n", err) + } + if m == nil { + t.Fatal("Received nil map") + } + + ex := map[string]interface{}{ + "listen": "127.0.0.1:4223", + "debug": true, + "cluster": map[string]interface{}{ + "routes": []interface{}{ + "nats://nats-A:6222", + "nats://nats-C:6222", + "nats://nats-B:6222", + }, + }, + } + if !reflect.DeepEqual(m, ex) { + t.Fatalf("Not Equal:\nReceived: '%+v'\nExpected: '%+v'\n", m, ex) + } + + content = ` + 'debug': true + + 'include': 'simple.json' + + 'cluster': { + 'include': 'cluster.json' + } + ` + configFile = filepath.Join(tmpdir, "multiple.json") + if err := ioutil.WriteFile(configFile, []byte(content), 0666); err != nil { + t.Fatal(err) + } + + m2, err := ParseFile(configFile) + if err != nil { + t.Fatalf("Received err: %v\n", err) + } + if m2 == nil { + t.Fatal("Received nil map") + } + if !reflect.DeepEqual(m2, ex) { + t.Fatalf("Not Equal:\nReceived: '%+v'\nExpected: '%+v'\n", m2, ex) + } +}