Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix quoted comments #370

Merged
merged 10 commits into from
Sep 15, 2023
Merged
7 changes: 0 additions & 7 deletions parser/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ type context struct {
idx int
size int
tokens token.Tokens
mode Mode
path string
}

Expand Down Expand Up @@ -56,7 +55,6 @@ func (c *context) copy() *context {
idx: c.idx,
size: c.size,
tokens: append(token.Tokens{}, c.tokens...),
mode: c.mode,
path: c.path,
}
}
Expand Down Expand Up @@ -145,10 +143,6 @@ func (c *context) afterNextNotCommentToken() *token.Token {
return nil
}

func (c *context) enabledComment() bool {
return c.mode&ParseComments != 0
}

func (c *context) isCurrentCommentToken() bool {
tk := c.currentToken()
if tk == nil {
Expand Down Expand Up @@ -193,7 +187,6 @@ func newContext(tokens token.Tokens, mode Mode) *context {
idx: 0,
size: len(filteredTokens),
tokens: token.Tokens(filteredTokens),
mode: mode,
path: "$",
}
}
4 changes: 4 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,8 @@ a: # commentA
i: fuga # commentI
j: piyo # commentJ
k.l.m.n: moge # commentKLMN
o#p: hogera # commentOP
q#.r: hogehoge # commentQR
`
f, err := parser.ParseBytes([]byte(yml), parser.ParseComments)
if err != nil {
Expand Down Expand Up @@ -854,6 +856,8 @@ k.l.m.n: moge # commentKLMN
"$.a.i",
"$.j",
"$.'k.l.m.n'",
"$.o#p",
"$.'q#.r'",
}
if !reflect.DeepEqual(expectedPaths, commentPaths) {
t.Fatalf("failed to get YAMLPath to the comment node:\nexpected[%s]\ngot [%s]", expectedPaths, commentPaths)
Expand Down
26 changes: 23 additions & 3 deletions path.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,29 @@ func newSelectorNode(selector string) *selectorNode {
}

func (n *selectorNode) filter(node ast.Node) (ast.Node, error) {
selector := n.selector
if len(selector) > 1 && selector[0] == '\'' && selector[len(selector)-1] == '\'' {
selector = selector[1 : len(selector)-1]
WillAbides marked this conversation as resolved.
Show resolved Hide resolved
}
switch node.Type() {
case ast.MappingType:
for _, value := range node.(*ast.MappingNode).Values {
key := value.Key.GetToken().Value
if key == n.selector {
if len(key) > 0 {
switch key[0] {
case '"':
var err error
key, err = strconv.Unquote(key)
if err != nil {
return nil, errors.Wrapf(err, "failed to unquote")
}
case '\'':
if len(key) > 1 && key[len(key)-1] == '\'' {
key = key[1 : len(key)-1]
}
}
}
if key == selector {
if n.child == nil {
return value.Value, nil
}
Expand All @@ -518,7 +536,7 @@ func (n *selectorNode) filter(node ast.Node) (ast.Node, error) {
case ast.MappingValueType:
value := node.(*ast.MappingValueNode)
key := value.Key.GetToken().Value
if key == n.selector {
if key == selector {
if n.child == nil {
return value.Value, nil
}
Expand Down Expand Up @@ -571,7 +589,9 @@ func (n *selectorNode) replace(node ast.Node, target ast.Node) error {
}

func (n *selectorNode) String() string {
s := fmt.Sprintf(".%s", n.selector)
var builder PathBuilder
selector := builder.normalizeSelectorName(n.selector)
s := fmt.Sprintf(".%s", selector)
if n.child != nil {
s += n.child.String()
}
Expand Down
7 changes: 7 additions & 0 deletions path_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ store:
bicycle:
color: red
price: 19.95
bicycle*unicycle:
price: 20.25
`
tests := []struct {
name string
Expand Down Expand Up @@ -97,6 +99,11 @@ store:
path: builder().Root().Child("store").Child("bicycle").Child("price").Build(),
expected: float64(19.95),
},
{
name: `$.store.'bicycle*unicycle'.price`,
path: builder().Root().Child("store").Child(`bicycle*unicycle`).Child("price").Build(),
expected: float64(20.25),
},
}
t.Run("PathString", func(t *testing.T) {
for _, test := range tests {
Expand Down
89 changes: 89 additions & 0 deletions yaml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package yaml_test

import (
"bytes"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -1161,6 +1162,94 @@ hoge:
})
}

func TestCommentMapRoundTrip(t *testing.T) {
// test that an unmarshal and marshal round trip retains comments.
// if expect is empty, the test will use the input as the expected result.
tests := []struct {
name string
source string
expect string
encodeOptions []yaml.EncodeOption
}{
{
name: "simple map",
source: `
# head
a: 1 # line
# foot
`,
},
{
name: "nesting",
source: `
- 1 # one
- foo:
a: b
# c comment
c: d # d comment
"e#f": g # g comment
h.i: j # j comment
"k.#l": m # m comment
`,
},
{
name: "single quotes",
source: `'a#b': c # c comment`,
encodeOptions: []yaml.EncodeOption{yaml.UseSingleQuote(true)},
},
{
name: "single quotes added in encode",
source: `a#b: c # c comment`,
encodeOptions: []yaml.EncodeOption{yaml.UseSingleQuote(true)},
expect: `'a#b': c # c comment`,
},
{
name: "double quotes quotes transformed to single quotes",
source: `"a#b": c # c comment`,
encodeOptions: []yaml.EncodeOption{yaml.UseSingleQuote(true)},
expect: `'a#b': c # c comment`,
},
{
name: "single quotes quotes transformed to double quotes",
source: `'a#b': c # c comment`,
expect: `"a#b": c # c comment`,
},
{
name: "single quotes removed",
source: `'a': b # b comment`,
expect: `a: b # b comment`,
},
{
name: "double quotes removed",
source: `"a": b # b comment`,
expect: `a: b # b comment`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var val any
cm := yaml.CommentMap{}
source := strings.TrimSpace(test.source)
if err := yaml.UnmarshalWithOptions([]byte(source), &val, yaml.CommentToMap(cm)); err != nil {
t.Fatalf("%+v", err)
}
marshaled, err := yaml.MarshalWithOptions(val, append(test.encodeOptions, yaml.WithComment(cm))...)
if err != nil {
t.Fatalf("%+v", err)
}
got := strings.TrimSpace(string(marshaled))
expect := strings.TrimSpace(test.expect)
if expect == "" {
expect = source
}
if got != expect {
t.Fatalf("expected:\n%s\ngot:\n%s\n", expect, got)
}
})

}
}

func TestRegisterCustomMarshaler(t *testing.T) {
type T struct {
Foo []byte `yaml:"foo"`
Expand Down