From 4df89231ed8de973237b11fe5ee21fa25cd2e420 Mon Sep 17 00:00:00 2001 From: WillAbides <233500+WillAbides@users.noreply.github.com> Date: Thu, 14 Sep 2023 21:33:31 -0500 Subject: [PATCH] Fix quoted comments (#370) * Make path filtering work with quotes --- parser/context.go | 7 ---- parser/parser_test.go | 4 ++ path.go | 26 +++++++++++-- path_test.go | 7 ++++ yaml_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 123 insertions(+), 10 deletions(-) diff --git a/parser/context.go b/parser/context.go index 99f18b1..42cc4f8 100644 --- a/parser/context.go +++ b/parser/context.go @@ -13,7 +13,6 @@ type context struct { idx int size int tokens token.Tokens - mode Mode path string } @@ -56,7 +55,6 @@ func (c *context) copy() *context { idx: c.idx, size: c.size, tokens: append(token.Tokens{}, c.tokens...), - mode: c.mode, path: c.path, } } @@ -145,10 +143,6 @@ func (c *context) afterNextNotCommentToken() *token.Token { return nil } -func (c *context) enabledComment() bool { - return c.mode&ParseComments != 0 -} - func (c *context) isCurrentCommentToken() bool { tk := c.currentToken() if tk == nil { @@ -193,7 +187,6 @@ func newContext(tokens token.Tokens, mode Mode) *context { idx: 0, size: len(filteredTokens), tokens: token.Tokens(filteredTokens), - mode: mode, path: "$", } } diff --git a/parser/parser_test.go b/parser/parser_test.go index d24b262..13157c0 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -892,6 +892,8 @@ a: # commentA i: fuga # commentI j: piyo # commentJ k.l.m.n: moge # commentKLMN +o#p: hogera # commentOP +q#.r: hogehoge # commentQR ` f, err := parser.ParseBytes([]byte(yml), parser.ParseComments) if err != nil { @@ -922,6 +924,8 @@ k.l.m.n: moge # commentKLMN "$.a.i", "$.j", "$.'k.l.m.n'", + "$.o#p", + "$.'q#.r'", } if !reflect.DeepEqual(expectedPaths, commentPaths) { t.Fatalf("failed to get YAMLPath to the comment node:\nexpected[%s]\ngot [%s]", expectedPaths, commentPaths) diff --git a/path.go b/path.go index 7a0c3b1..72554bd 100644 --- a/path.go +++ b/path.go @@ -500,11 +500,29 @@ func newSelectorNode(selector string) *selectorNode { } func (n *selectorNode) filter(node ast.Node) (ast.Node, error) { + selector := n.selector + if len(selector) > 1 && selector[0] == '\'' && selector[len(selector)-1] == '\'' { + selector = selector[1 : len(selector)-1] + } switch node.Type() { case ast.MappingType: for _, value := range node.(*ast.MappingNode).Values { key := value.Key.GetToken().Value - if key == n.selector { + if len(key) > 0 { + switch key[0] { + case '"': + var err error + key, err = strconv.Unquote(key) + if err != nil { + return nil, errors.Wrapf(err, "failed to unquote") + } + case '\'': + if len(key) > 1 && key[len(key)-1] == '\'' { + key = key[1 : len(key)-1] + } + } + } + if key == selector { if n.child == nil { return value.Value, nil } @@ -518,7 +536,7 @@ func (n *selectorNode) filter(node ast.Node) (ast.Node, error) { case ast.MappingValueType: value := node.(*ast.MappingValueNode) key := value.Key.GetToken().Value - if key == n.selector { + if key == selector { if n.child == nil { return value.Value, nil } @@ -571,7 +589,9 @@ func (n *selectorNode) replace(node ast.Node, target ast.Node) error { } func (n *selectorNode) String() string { - s := fmt.Sprintf(".%s", n.selector) + var builder PathBuilder + selector := builder.normalizeSelectorName(n.selector) + s := fmt.Sprintf(".%s", selector) if n.child != nil { s += n.child.String() } diff --git a/path_test.go b/path_test.go index 4271da5..c0073ce 100644 --- a/path_test.go +++ b/path_test.go @@ -61,6 +61,8 @@ store: bicycle: color: red price: 19.95 + bicycle*unicycle: + price: 20.25 ` tests := []struct { name string @@ -97,6 +99,11 @@ store: path: builder().Root().Child("store").Child("bicycle").Child("price").Build(), expected: float64(19.95), }, + { + name: `$.store.'bicycle*unicycle'.price`, + path: builder().Root().Child("store").Child(`bicycle*unicycle`).Child("price").Build(), + expected: float64(20.25), + }, } t.Run("PathString", func(t *testing.T) { for _, test := range tests { diff --git a/yaml_test.go b/yaml_test.go index 5828629..4446d31 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -2,6 +2,7 @@ package yaml_test import ( "bytes" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -1161,6 +1162,94 @@ hoge: }) } +func TestCommentMapRoundTrip(t *testing.T) { + // test that an unmarshal and marshal round trip retains comments. + // if expect is empty, the test will use the input as the expected result. + tests := []struct { + name string + source string + expect string + encodeOptions []yaml.EncodeOption + }{ + { + name: "simple map", + source: ` +# head +a: 1 # line +# foot +`, + }, + { + name: "nesting", + source: ` +- 1 # one +- foo: + a: b + # c comment + c: d # d comment + "e#f": g # g comment + h.i: j # j comment + "k.#l": m # m comment +`, + }, + { + name: "single quotes", + source: `'a#b': c # c comment`, + encodeOptions: []yaml.EncodeOption{yaml.UseSingleQuote(true)}, + }, + { + name: "single quotes added in encode", + source: `a#b: c # c comment`, + encodeOptions: []yaml.EncodeOption{yaml.UseSingleQuote(true)}, + expect: `'a#b': c # c comment`, + }, + { + name: "double quotes quotes transformed to single quotes", + source: `"a#b": c # c comment`, + encodeOptions: []yaml.EncodeOption{yaml.UseSingleQuote(true)}, + expect: `'a#b': c # c comment`, + }, + { + name: "single quotes quotes transformed to double quotes", + source: `'a#b': c # c comment`, + expect: `"a#b": c # c comment`, + }, + { + name: "single quotes removed", + source: `'a': b # b comment`, + expect: `a: b # b comment`, + }, + { + name: "double quotes removed", + source: `"a": b # b comment`, + expect: `a: b # b comment`, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var val any + cm := yaml.CommentMap{} + source := strings.TrimSpace(test.source) + if err := yaml.UnmarshalWithOptions([]byte(source), &val, yaml.CommentToMap(cm)); err != nil { + t.Fatalf("%+v", err) + } + marshaled, err := yaml.MarshalWithOptions(val, append(test.encodeOptions, yaml.WithComment(cm))...) + if err != nil { + t.Fatalf("%+v", err) + } + got := strings.TrimSpace(string(marshaled)) + expect := strings.TrimSpace(test.expect) + if expect == "" { + expect = source + } + if got != expect { + t.Fatalf("expected:\n%s\ngot:\n%s\n", expect, got) + } + }) + + } +} + func TestRegisterCustomMarshaler(t *testing.T) { type T struct { Foo []byte `yaml:"foo"`