diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 70374c0..587bff9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,10 +17,10 @@ jobs: steps: - name: Check out code into the Go module directory - uses: actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 with: go-version-file: go.mod @@ -32,7 +32,7 @@ jobs: RICHGO_FORCE_COLOR: 1 - name: golangci-lint - uses: golangci/golangci-lint-action@v7 + uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0 with: version: v2.1 args: --issues-exit-code=1 --timeout 10m diff --git a/.golangci.yml b/.golangci.yml index f20fcd3..9d99f7a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -11,6 +11,7 @@ linters: - inamedparam # reports interfaces with unnamed method parameters - wrapcheck # Checks that errors returned from external packages are wrapped - err113 # Go linter to check the errors handling expressions + #- noinlineerr - paralleltest # Detects missing usage of t.Parallel() method in your Go test - testpackage # linter that makes you use a separate _test package - exhaustruct # Checks if all structure fields are initialized @@ -27,13 +28,14 @@ linters: - gocognit # revive - gocyclo # revive - lll # revive + - wsl # wsl_v5 # # Formatting only, useful in IDE but should not be forced on CI? # - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - - wsl # add or remove empty lines + #- wsl_v5 # add or remove empty lines settings: @@ -41,7 +43,7 @@ linters: rules: yaml: files: - - 'yamlpatch/patcher.go' + - '!**/yamlpatch/patcher.go' deny: - pkg: gopkg.in/yaml.v2 desc: yaml.v2 is deprecated for new code in favor of yaml.v3 @@ -102,6 +104,8 @@ linters: - 43 - name: defer disabled: true + #- name: enforce-switch-style + # disabled: true - name: flag-parameter disabled: true - name: function-length diff --git a/csyaml/empty.go b/csyaml/empty.go new file mode 100644 index 0000000..6d5d91d --- /dev/null +++ b/csyaml/empty.go @@ -0,0 +1,44 @@ +package csyaml + +import ( + "errors" + "io" + + yaml "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/parser" +) + +// IsEmptyYAML reads one or more YAML documents from r and returns true +// if they are all empty or contain only comments. +// It will reports errors if the input is not valid YAML. +func IsEmptyYAML(r io.Reader) (bool, error) { + src, err := io.ReadAll(r) + if err != nil { + return false, err + } + + if len(src) == 0 { + return true, nil + } + + file, err := parser.ParseBytes(src, 0) + if err != nil { + if errors.Is(err, io.EOF) { + return true, nil + } + + return false, errors.New(yaml.FormatError(err, false, false)) + } + + if file == nil || len(file.Docs) == 0 { + return true, nil + } + + for _, doc := range file.Docs { + if doc.Body != nil { + return false, nil + } + } + + return true, nil +} diff --git a/csyaml/empty_test.go b/csyaml/empty_test.go new file mode 100644 index 0000000..d75347e --- /dev/null +++ b/csyaml/empty_test.go @@ -0,0 +1,114 @@ +package csyaml + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/go-cs-lib/cstest" // adjust this import to your package +) + +func TestIsEmptyYAML(t *testing.T) { + tests := []struct { + name string + input string + want bool + wantErr string + }{ + { + name: "empty document", + input: ``, + want: true, + }, + { + name: "just a key", + input: "foo:", + want: false, + }, + { + name: "just newline", + input: "\n", + want: true, + }, + { + name: "just comment", + input: "# only a comment", + want: true, + }, + { + name: "comments and empty lines", + input: "# only a comment\n\n# another one\n\n", + want: true, + }, + { + name: "empty doc with separator", + input: "---", + want: true, + }, + { + name: "empty mapping", + input: "{}", + want: false, + }, + { + name: "empty sequence", + input: "[]", + want: false, + }, + { + name: "non-empty mapping", + input: "foo: bar", + want: false, + }, + { + name: "non-empty sequence", + input: "- 1\n- 2", + want: false, + }, + { + name: "non-empty scalar", + input: "hello", + want: false, + }, + { + name: "empty scalar", + input: "''", + want: false, + }, + { + name: "explicit nil", + input: "null", + want: false, + }, + { + name: "malformed YAML", + input: "foo: [1,", + wantErr: "[1:6] sequence end token ']' not found", + }, + { + name: "multiple empty documents", + input: "---\n---\n---\n#comment", + want: true, + }, + { + name: "second document is not empty", + input: "---\nfoo: bar\n---\n#comment", + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := IsEmptyYAML(strings.NewReader(tc.input)) + + cstest.RequireErrorContains(t, err, tc.wantErr) + + if tc.wantErr != "" { + return + } + + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/csyaml/keys.go b/csyaml/keys.go new file mode 100644 index 0000000..1c3247a --- /dev/null +++ b/csyaml/keys.go @@ -0,0 +1,54 @@ +package csyaml + +import ( + "errors" + "fmt" + "io" + + "github.com/goccy/go-yaml" +) + +// GetDocumentKeys reads all YAML documents from r and for each one +// returns a slice of its top-level keys, in order. +// +// Non-mapping documents yield an empty slice. Duplicate keys +// are not allowed and return an error. +func GetDocumentKeys(r io.Reader) ([][]string, error) { + // Decode into Go types, but force mappings into MapSlice + dec := yaml.NewDecoder(r, yaml.UseOrderedMap()) + + allKeys := make([][]string, 0) + + idx := -1 + + for { + var raw any + + idx++ + + if err := dec.Decode(&raw); err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, fmt.Errorf("position %d: %s", idx, yaml.FormatError(err, false, false)) + } + keys := []string{} + + // Only mapping nodes become MapSlice with UseOrderedMap() + if ms, ok := raw.(yaml.MapSlice); ok { + for _, item := range ms { + // Key is interface{}—here we expect strings + if ks, ok := item.Key.(string); ok { + keys = append(keys, ks) + } else { + // fallback to string form of whatever it is + keys = append(keys, fmt.Sprint(item.Key)) + } + } + } + + allKeys = append(allKeys, keys) + } + + return allKeys, nil +} diff --git a/csyaml/keys_test.go b/csyaml/keys_test.go new file mode 100644 index 0000000..50b1427 --- /dev/null +++ b/csyaml/keys_test.go @@ -0,0 +1,67 @@ +package csyaml_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/csyaml" +) + +func TestCollectTopLevelKeys(t *testing.T) { + tests := []struct { + name string + input string + want [][]string + wantErr string + }{ + { + name: "single mapping", + input: "a: 1\nb: 2\n", + want: [][]string{{"a", "b"}}, + }, + { + name: "duplicate keys mapping", + input: "a: 1\nb: 2\na: 3\n", + want: nil, + wantErr: `position 0: [3:1] mapping key "a" already defined at [1:1]`, + }, + { + name: "multiple documents", + input: `--- +a: 1 +b: 2 +--- +- 1 +--- +c: 1 +b: 2 +--- +"scalar" +`, + want: [][]string{{"a", "b"}, {}, {"c", "b"}, {}}, + }, + { + name: "empty input", + input: "", + want: [][]string{}, + }, + { + name: "invalid YAML", + input: "list: [1, 2,", + want: nil, + wantErr: "position 0: [1:7] sequence end token ']' not found", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := strings.NewReader(tc.input) + got, err := csyaml.GetDocumentKeys(r) + cstest.RequireErrorContains(t, err, tc.wantErr) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/csyaml/merge.go b/csyaml/merge.go new file mode 100644 index 0000000..7c9a547 --- /dev/null +++ b/csyaml/merge.go @@ -0,0 +1,133 @@ +// Package merge implements a deep-merge over multiple YAML documents, +// preserving key order and rejecting invalid documents. +// +// Maps are deep-merged; sequences and scalars are replaced by later inputs. +// Type mismatches result in an error. +// +// Adapted from https://github.com/uber-go/config/tree/master/internal/merge +package csyaml + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + + "github.com/goccy/go-yaml" +) + +// Merge reads each YAML source in inputs, merges them in order (later +// sources override earlier), and returns the result as a bytes.Buffer. +// Always runs in strict mode: type mismatches or duplicate keys cause an error. +func Merge(inputs [][]byte) (*bytes.Buffer, error) { + var merged any + hasContent := false + for idx, data := range inputs { + dec := yaml.NewDecoder(bytes.NewReader(data), yaml.UseOrderedMap(), yaml.Strict()) + + var value any + if err := dec.Decode(&value); err != nil { + if errors.Is(err, io.EOF) { + continue + } + return nil, fmt.Errorf("decoding document %d: %s", idx, yaml.FormatError(err, false, false)) + } + hasContent = true + + mergedValue, err := mergeValue(merged, value) + if err != nil { + return nil, err + } + + merged = mergedValue + } + + buf := &bytes.Buffer{} + if merged == nil && !hasContent { + return buf, nil + } + + enc := yaml.NewEncoder(buf) + if err := enc.Encode(merged); err != nil { + return nil, fmt.Errorf("encoding merged YAML: %w", err) + } + + return buf, nil +} + +// mergeValue merges from+into in strict mode. +func mergeValue(into, from any) (any, error) { + if into == nil { + return from, nil + } + + if from == nil { + return nil, nil + } + + // Scalars: override + if !isMapping(into) && !isSequence(into) && !isMapping(from) && !isSequence(from) { + return from, nil + } + + // Sequences: replace + if isSequence(into) && isSequence(from) { + return from, nil + } + + // Mappings: deep-merge + if mi, ok := into.(yaml.MapSlice); ok { + if mf, ok2 := from.(yaml.MapSlice); ok2 { + return mergeMap(mi, mf) + } + } + + // Type mismatch: strict + return nil, fmt.Errorf("can't merge a %s into a %s", describe(from), describe(into)) +} + +// mergeMap deep-merges two ordered maps (MapSlice) in strict mode. +func mergeMap(into, from yaml.MapSlice) (yaml.MapSlice, error) { + out := make(yaml.MapSlice, len(into)) + copy(out, into) + for _, item := range from { + matched := false + for i, existing := range out { + if !reflect.DeepEqual(existing.Key, item.Key) { + continue + } + + mergedVal, err := mergeValue(existing.Value, item.Value) + if err != nil { + return nil, err + } + out[i].Value = mergedVal + matched = true + } + if !matched { + out = append(out, yaml.MapItem{Key: item.Key, Value: item.Value}) + } + } + return out, nil +} + +func isMapping(i any) bool { + _, ok := i.(yaml.MapSlice) + return ok +} + +func isSequence(i any) bool { + _, ok := i.([]any) + return ok +} + +func describe(i any) string { + if isMapping(i) { + return "mapping" + } + if isSequence(i) { + return "sequence" + } + return "scalar" +} diff --git a/csyaml/merge_test.go b/csyaml/merge_test.go new file mode 100644 index 0000000..ceadb35 --- /dev/null +++ b/csyaml/merge_test.go @@ -0,0 +1,209 @@ +package csyaml_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/csyaml" +) + +func TestMergeYAML(t *testing.T) { + tests := []struct { + name string + inputs []string + want string + wantErr string + }{ + { + name: "single doc passes through", + inputs: []string{"a: 1\nb: 2\n"}, + want: "a: 1\nb: 2\n", + }, + { + name: "merge maps deep", + inputs: []string{ + "one: 1\ntwo: 2\n", + "two: 20\nthree: 3\n", + }, + want: "one: 1\ntwo: 20\nthree: 3\n", + }, + { + name: "sequence replaced", + inputs: []string{ + "list: [1,2,3]\n", + "list: [4,5]\n", + }, + want: "list:\n- 4\n- 5\n", + }, + { + name: "scalar override", + inputs: []string{ + "foo: bar\n", + "foo: baz\n", + }, + want: "foo: baz\n", + }, + { + name: "type mismatch error", + inputs: []string{"foo: 1\n", "foo:\n - a\n"}, + wantErr: "can't merge a sequence into a scalar", + }, + { + name: "invalid yaml error", + inputs: []string{"ref: *foo\n"}, // undefined alias + wantErr: `decoding document 0: [1:7] could not find alias "foo"`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var bs [][]byte + for _, s := range tc.inputs { + bs = append(bs, []byte(s)) + } + + buf, err := csyaml.Merge(bs) + cstest.RequireErrorContains(t, err, tc.wantErr) + if tc.wantErr != "" { + require.Nil(t, buf) + return + } + + require.NotNil(t, buf) + assert.Equal(t, tc.want, buf.String()) + }) + } +} + +func TestEmptyVsNilSources(t *testing.T) { + tests := []struct { + desc string + sources [][]byte + expect string + }{ + {"empty base", [][]byte{nil, []byte("foo: bar\n")}, "foo: bar\n"}, + {"empty override", [][]byte{[]byte("foo: bar\n"), nil}, "foo: bar\n"}, + {"both empty", [][]byte{nil, nil}, ""}, + {"null base", [][]byte{[]byte("~\n"), []byte("foo: bar\n")}, "foo: bar\n"}, + {"explicit null override", [][]byte{[]byte("foo: bar\n"), []byte("~\n")}, "null\n"}, + {"empty base & null override", [][]byte{nil, []byte("~\n")}, "null\n"}, + {"null base & empty override", [][]byte{[]byte("~\n"), nil}, "null\n"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + merged, err := csyaml.Merge(tt.sources) + require.NoError(t, err) + assert.Equal(t, tt.expect, merged.String()) + }) + } +} + +func TestDuplicateKeyError(t *testing.T) { + src := []byte("{foo: bar, foo: baz}") + _, err := csyaml.Merge([][]byte{src}) + cstest.RequireErrorContains(t, err, `decoding document 0: [1:12] mapping key "foo" already defined at [1:2]`) +} + +func TestTabsInSource(t *testing.T) { + src := []byte("foo:\n\tbar: baz") + _, err := csyaml.Merge([][]byte{src}) + cstest.RequireErrorContains(t, err, "decoding document 0: [2:1] found character '\t' that cannot start any token") +} + +func TestNestedDeepMergePreservesOrder(t *testing.T) { + left := ` +settings: + ui: + theme: light + toolbar: + - cut + - copy +` + right := ` +settings: + ui: + toolbar: + - paste + security: strict +` + expect := `settings: + ui: + theme: light + toolbar: + - paste + security: strict +` + merged, err := csyaml.Merge([][]byte{[]byte(left), []byte(right)}) + require.NoError(t, err) + assert.Equal(t, expect, merged.String()) +} + +// Don't coerce boolean-like strings to true/false (YAML 1.2 / goccy/go-yaml behavior). +func TestBooleanNoCoercion(t *testing.T) { + tests := []struct { + in, out string + }{ + {"foo: yes", `foo: "yes"`}, + {"foo: YES", `foo: "YES"`}, + {"foo: no", `foo: "no"`}, + {"foo: NO", `foo: "NO"`}, + {"foo: on", `foo: "on"`}, + {"foo: ON", `foo: "ON"`}, + {"foo: off", `foo: "off"`}, + {"foo: OFF", `foo: "OFF"`}, + } + + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + buf, err := csyaml.Merge([][]byte{nil, []byte(tt.in)}) + require.NoError(t, err) + assert.Equal(t, tt.out, strings.TrimSuffix(buf.String(), "\n")) + }) + } +} + +// Do coerce boolean-like values to true/false (YAML 1.1 / yaml.v3 behavior). +// func TestBooleanCoercion(t *testing.T) { +// tests := []struct { +// in, out string +// }{ +// {"yes\n", "true\n"}, +// {"YES\n", "true\n"}, +// {"no\n", "false\n"}, +// {"NO\n", "false\n"}, +// {"on\n", "true\n"}, +// {"ON\n", "true\n"}, +// {"off\n", "false\n"}, +// {"OFF\n", "false\n"}, +// } +// +// for _, tt := range tests { +// t.Run(tt.in, func(t *testing.T) { +// buf, err := csyaml.Merge([][]byte{nil, []byte(tt.in)}) +// require.NoError(t, err) +// assert.Equal(t, tt.out, buf.String()) +// }) +// } +//} + +func TestExplicitNilOverride(t *testing.T) { + base := []byte("foo: {one: two}\n") + override := []byte("foo: ~\n") + merged, err := csyaml.Merge([][]byte{base, override}) + require.NoError(t, err) + assert.Equal(t, "foo: null\n", merged.String()) +} + +func TestOrderPreservation(t *testing.T) { + left := []byte("a: 1\nb: 2\n") + right := []byte("c: 3\nb: 20\n") + expect := "a: 1\nb: 20\nc: 3\n" + merged, err := csyaml.Merge([][]byte{left, right}) + require.NoError(t, err) + assert.Equal(t, expect, merged.String()) +} diff --git a/csyaml/patcher.go b/csyaml/patcher.go new file mode 100644 index 0000000..8e9ffe7 --- /dev/null +++ b/csyaml/patcher.go @@ -0,0 +1,159 @@ +package csyaml + +import ( + "bytes" + "errors" + "fmt" + "io" + "os" + + "github.com/goccy/go-yaml" + "github.com/sirupsen/logrus" +) + +type Patcher struct { + BaseFilePath string + PatchFilePath string + quiet bool +} + +func NewPatcher(filePath string, suffix string) *Patcher { + return &Patcher{ + BaseFilePath: filePath, + PatchFilePath: filePath + suffix, + quiet: false, + } +} + +// SetQuiet sets the quiet flag, which will log as DEBUG_LEVEL instead of INFO. +func (p *Patcher) SetQuiet(quiet bool) { + p.quiet = quiet +} + +// read a single YAML file, check for errors (the merge package doesn't) then return the content as bytes. +func readYAML(filePath string) ([]byte, error) { + content, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("while reading yaml file: %w", err) + } + + var yamlMap map[any]any + + if err = yaml.Unmarshal(content, &yamlMap); err != nil { + return nil, fmt.Errorf("%s: %s", filePath, yaml.FormatError(err, false, false)) + } + + return content, nil +} + +// MergedPatchContent reads a YAML file and, if it exists, its patch file, +// then merges them and returns it serialized. +func (p *Patcher) MergedPatchContent() ([]byte, error) { + base, err := readYAML(p.BaseFilePath) + if err != nil { + return nil, err + } + + over, err := readYAML(p.PatchFilePath) + if errors.Is(err, os.ErrNotExist) { + return base, nil + } + + if err != nil { + return nil, err + } + + logf := logrus.Infof + if p.quiet { + logf = logrus.Debugf + } + + logf("Loading yaml file: '%s' with additional values from '%s'", p.BaseFilePath, p.PatchFilePath) + + // strict mode true, will raise errors for duplicate map keys and + // overriding with a different type + patched, err := Merge([][]byte{base, over}) + if err != nil { + return nil, err + } + + return patched.Bytes(), nil +} + +// read multiple YAML documents inside a file, and writes them to a buffer +// separated by the appropriate '---' terminators. +func decodeDocuments(file *os.File, buf *bytes.Buffer, finalDashes bool) error { + dec := yaml.NewDecoder(file, yaml.Strict()) + + dashTerminator := false + + for { + yml := make(map[any]any) + + err := dec.Decode(&yml) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + return fmt.Errorf("while decoding %s: %s", file.Name(), yaml.FormatError(err, false, false)) + } + + docBytes, err := yaml.Marshal(&yml) + if err != nil { + return fmt.Errorf("while marshaling %s: %w", file.Name(), err) + } + + if dashTerminator { + buf.WriteString("---\n") + } + + buf.Write(docBytes) + + dashTerminator = true + } + + if dashTerminator && finalDashes { + buf.WriteString("---\n") + } + + return nil +} + +// PrependedPatchContent collates the base .yaml file with the .yaml.patch, by putting +// the content of the patch BEFORE the base document. The result is a multi-document +// YAML in all cases, even if the base and patch files are single documents. +func (p *Patcher) PrependedPatchContent() ([]byte, error) { + patchFile, err := os.Open(p.PatchFilePath) + // optional file, ignore if it does not exist + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("while opening %s: %w", p.PatchFilePath, err) + } + + var result bytes.Buffer + + if patchFile != nil { + if err = decodeDocuments(patchFile, &result, true); err != nil { + return nil, err + } + + logf := logrus.Infof + + if p.quiet { + logf = logrus.Debugf + } + + logf("Prepending yaml: '%s' with '%s'", p.BaseFilePath, p.PatchFilePath) + } + + baseFile, err := os.Open(p.BaseFilePath) + if err != nil { + return nil, fmt.Errorf("while opening %s: %w", p.BaseFilePath, err) + } + + if err = decodeDocuments(baseFile, &result, false); err != nil { + return nil, err + } + + return result.Bytes(), nil +} diff --git a/csyaml/patcher_test.go b/csyaml/patcher_test.go new file mode 100644 index 0000000..6617da3 --- /dev/null +++ b/csyaml/patcher_test.go @@ -0,0 +1,288 @@ +package csyaml_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/csyaml" +) + +func TestMergedPatchContent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + patch string + expected string + expectedErr string + }{ + { + "invalid yaml in base", + "notayaml", + "", + "", + "config.yaml: [1:1] string was used where mapping is expected", + }, + { + "invalid yaml in base (detailed message)", + "notayaml", + "", + "", + "config.yaml: [1:1] string was used where mapping is expected", + }, + { + "invalid yaml in patch", + "", + "notayaml", + "", + "config.yaml.local: [1:1] string was used where mapping is expected", + }, + { + "invalid yaml in patch (detailed message)", + "", + "notayaml", + "", + "config.yaml.local: [1:1] string was used where mapping is expected", + }, + { + "basic merge", + "{'first':{'one':1,'two':2},'second':{'three':3}}", + "{'first':{'one':10,'dos':2}}", + "{'first':{'one':10,'dos':2,'two':2},'second':{'three':3}}", + "", + }, + + // bools and zero values; here the "mergo" package had issues + // so we used something simpler. + + { + "don't convert on/off to boolean", + "bool: on", + "bool: off", + "bool: off", + "", + }, + { + "string is not a bool - on to off", + "{'bool': 'on'}", + "{'bool': 'off'}", + "{'bool': 'off'}", + "", + }, + { + "string is not a bool - off to on", + "{'bool': 'off'}", + "{'bool': 'on'}", + "{'bool': 'on'}", + "", + }, + { + "bool merge - true to false", + "{'bool': true}", + "{'bool': false}", + "{'bool': false}", + "", + }, + { + "bool merge - false to true", + "{'bool': false}", + "{'bool': true}", + "{'bool': true}", + "", + }, + { + "string merge - value to value", + "{'string': 'value'}", + "{'string': ''}", + "{'string': ''}", + "", + }, + { + "sequence merge - value to empty", + "{'sequence': [1, 2]}", + "{'sequence': []}", + "{'sequence': []}", + "", + }, + { + "map merge - value to value", + "{'map': {'one': 1, 'two': 2}}", + "{'map': {}}", + "{'map': {'one': 1, 'two': 2}}", + "", + }, + + // mismatched types + + { + "can't merge a sequence into a mapping", + "map: {'key': 'value'}", + "map: ['value1', 'value2']", + "", + "can't merge a sequence into a mapping", + }, + { + "can't merge a scalar into a mapping", + "map: {'key': 'value'}", + "map: 3", + "", + "can't merge a scalar into a mapping", + }, + { + "can't merge a mapping into a sequence", + "sequence: ['value1', 'value2']", + "sequence: {'key': 'value'}", + "", + "can't merge a mapping into a sequence", + }, + { + "can't merge a scalar into a sequence", + "sequence: ['value1', 'value2']", + "sequence: 3", + "", + "can't merge a scalar into a sequence", + }, + { + "can't merge a sequence into a scalar", + "scalar: true", + "scalar: ['value1', 'value2']", + "", + "can't merge a sequence into a scalar", + }, + { + "can't merge a mapping into a scalar", + "scalar: true", + "scalar: {'key': 'value'}", + "", + "can't merge a mapping into a scalar", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + dirPath := t.TempDir() + + configPath := filepath.Join(dirPath, "config.yaml") + patchPath := filepath.Join(dirPath, "config.yaml.local") + err := os.WriteFile(configPath, []byte(tc.base), 0o600) + require.NoError(t, err) + + err = os.WriteFile(patchPath, []byte(tc.patch), 0o600) + require.NoError(t, err) + + patcher := csyaml.NewPatcher(configPath, ".local") + patchedBytes, err := patcher.MergedPatchContent() + cstest.RequireErrorContains(t, err, tc.expectedErr) + require.YAMLEq(t, tc.expected, string(patchedBytes)) + }) + } +} + +func TestPrependedPatchContent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + patch string + expected string + expectedErr string + }{ + // we test with scalars here, because YAMLeq does not work + // with multi-document files, so we need char-to-char comparison + // which is noisy with sequences and (unordered) mappings + { + "newlines are always appended, if missing, by yaml.Marshal()", + "foo: bar", + "", + "foo: bar\n", + "", + }, + { + "prepend empty document", + "foo: bar\n", + "", + "foo: bar\n", + "", + }, + { + "prepend a document to another", + "foo: bar", + "baz: qux", + "baz: qux\n---\nfoo: bar\n", + "", + }, + { + "prepend document with same key", + "foo: true", + "foo: false", + "foo: false\n---\nfoo: true\n", + "", + }, + { + "prepend multiple documents", + "one: 1\n---\ntwo: 2\n---\none: 3", + "four: 4\n---\none: 1.1", + "four: 4\n---\none: 1.1\n---\none: 1\n---\ntwo: 2\n---\none: 3\n", + "", + }, + { + "invalid yaml in base", + "blablabla", + "", + "", + "config.yaml: [1:1] string was used where mapping is expected", + }, + { + "invalid yaml in base (detailed message)", + "blablabla", + "", + "", + "config.yaml: [1:1] string was used where mapping is expected", + }, + { + "invalid yaml in patch", + "", + "blablabla", + "", + "config.yaml.local: [1:1] string was used where mapping is expected", + }, + { + "invalid yaml in patch (detailed message)", + "", + "blablabla", + "", + "config.yaml.local: [1:1] string was used where mapping is expected", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + dirPath := t.TempDir() + + configPath := filepath.Join(dirPath, "config.yaml") + patchPath := filepath.Join(dirPath, "config.yaml.local") + + err := os.WriteFile(configPath, []byte(tc.base), 0o600) + require.NoError(t, err) + + err = os.WriteFile(patchPath, []byte(tc.patch), 0o600) + require.NoError(t, err) + + patcher := csyaml.NewPatcher(configPath, ".local") + patchedBytes, err := patcher.PrependedPatchContent() + cstest.RequireErrorContains(t, err, tc.expectedErr) + // YAMLeq does not handle multiple documents + require.Equal(t, tc.expected, string(patchedBytes)) + }) + } +} diff --git a/csyaml/split.go b/csyaml/split.go new file mode 100644 index 0000000..687e8a9 --- /dev/null +++ b/csyaml/split.go @@ -0,0 +1,35 @@ +package csyaml + +import ( + "bytes" + "io" +) + +// SplitDocuments returns a slice of byte slices, each representing a YAML document. +// +// Since preserving formatting and comments is important but the existing go packages +// all have some issue, this function attempts two strategies: one that decodes and +// re-encodes the YAML content, and another that simply splits the input text. +// If both methods return the same number of documents, we assume the text-based +// function is sufficient. It retains comments and formatting better. +// Otherwise, the round-trip version is used. It retains comments but +// the formatting may be off. The semantics of the document will still be the same +// but if it contains parsing errors, they may refer to a wrong line or column. +// +// This function returns reading errors but any parsing errors are ignored and +// trigger the text-based splitting method. +func SplitDocuments(r io.Reader) ([][]byte, error) { + input, err := io.ReadAll(r) + if err != nil { + return nil, err + } + + textDocs, errText := SplitDocumentsText(bytes.NewReader(input)) + decEncDocs, errDecEnc := SplitDocumentsDecEnc(bytes.NewReader(input)) + + if errDecEnc == nil && len(decEncDocs) != len(textDocs) { + return decEncDocs, nil + } + + return textDocs, errText +} diff --git a/csyaml/splittext.go b/csyaml/splittext.go new file mode 100644 index 0000000..9d3df58 --- /dev/null +++ b/csyaml/splittext.go @@ -0,0 +1,52 @@ +package csyaml + +import ( + "bufio" + "bytes" + "io" + "strings" +) + +// SplitDocumentsText splits a YAML input stream into separate documents by looking for the `---` separator. +// No encoding or decoding is performed; the input is treated as raw text. +// Comments and whitespace are preserved. Malformed documents are returned as-is. +func SplitDocumentsText(r io.Reader) ([][]byte, error) { + var ( + docs [][]byte + current bytes.Buffer + ) + + scanner := bufio.NewScanner(r) + + for scanner.Scan() { + line := scanner.Text() + trimmed := strings.TrimSpace(line) + + isSeparator := strings.HasPrefix(trimmed, "---") && + (trimmed == "---" || strings.HasPrefix(trimmed, "--- ")) + + // Always write the line first + current.WriteString(line) + current.WriteByte('\n') + + if isSeparator && current.Len() > len(line)+1 { // +1 for newline just added + // Separator starts a new doc → commit previous one + // (everything up to this line is the previous doc) + n := current.Len() + // rewind to just before this separator line + doc := current.Bytes()[:n-len(line)-1] + docs = append(docs, append([]byte(nil), doc...)) + current = *bytes.NewBuffer(current.Bytes()[n-len(line)-1:]) + } + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + if current.Len() > 0 { + docs = append(docs, current.Bytes()) + } + + return docs, nil +} diff --git a/csyaml/splittext_test.go b/csyaml/splittext_test.go new file mode 100644 index 0000000..6f14870 --- /dev/null +++ b/csyaml/splittext_test.go @@ -0,0 +1,112 @@ +package csyaml_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/csyaml" +) + +func TestSplitDocumentsText(t *testing.T) { + tests := []struct { + name string + input string + want [][]byte + wantErr string + }{ + { + name: "single mapping", + input: "a: 1\nb: 2\n", + want: [][]byte{[]byte("a: 1\nb: 2\n")}, + }, + { + name: "sequence doc", + input: "- 1\n- 2\n", + want: [][]byte{[]byte("- 1\n- 2\n")}, + }, + { + name: "scalar doc", + input: "\"scalar\"\n", + want: [][]byte{[]byte("\"scalar\"\n")}, + }, + { + name: "multiple documents", + input: `--- +a: 1 +b: 2 +--- +- 1 +- 2 +--- +"scalar" +`, + want: [][]byte{ + []byte("---\na: 1\nb: 2\n"), + []byte("---\n- 1\n- 2\n"), + []byte("---\n\"scalar\"\n"), + }, + }, + { + name: "empty input", + input: "", + want: [][]byte(nil), + }, + { + name: "invalid YAML", + input: "list: [1, 2,", + want: [][]byte{[]byte("list: [1, 2,\n")}, + }, + { + name: "preserve comments", + input: `# comment 1 +a: 1 +# comment 2 +b: 2 +--- +# comment 3 +- 1 +# comment 4 +- 2 +# comment 5 +--- +# comment 6 +"scalar" +# comment 7 +`, + want: [][]byte{ + []byte("# comment 1\na: 1\n# comment 2\nb: 2\n"), + []byte("---\n# comment 3\n- 1\n# comment 4\n- 2\n# comment 5\n"), + []byte("---\n# comment 6\n\"scalar\"\n# comment 7\n"), + }, + }, + { + name: "tricky separator", + input: `--- +text: | + This is a multi-line string. + It includes a line that looks like a document separator: + --- + But it's just part of the string. +--- +key: value +`, + want: [][]byte{ + []byte("---\ntext: |\n This is a multi-line string.\n It includes a line that looks like a document separator:\n"), + []byte(" ---\n But it's just part of the string.\n"), + []byte("---\nkey: value\n"), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := strings.NewReader(tc.input) + docs, err := csyaml.SplitDocumentsText(r) + cstest.RequireErrorContains(t, err, tc.wantErr) + assert.Equal(t, tc.want, docs) + }) + } +} diff --git a/csyaml/splityaml.go b/csyaml/splityaml.go new file mode 100644 index 0000000..0923331 --- /dev/null +++ b/csyaml/splityaml.go @@ -0,0 +1,48 @@ +package csyaml + +import ( + "bytes" + "errors" + "fmt" + "io" + + "gopkg.in/yaml.v3" +) + +// SplitDocumentsDecEnc splits documents from reader and returns them as +// re-encoded []byte slices, preserving comments but not exact original +// whitespace. It returns an error if any document cannot be decoded or +// re-encoded. +func SplitDocumentsDecEnc(r io.Reader) ([][]byte, error) { + dec := yaml.NewDecoder(r) + + var docs [][]byte + + idx := 0 + + for { + var node yaml.Node + if err := dec.Decode(&node); err != nil { + if errors.Is(err, io.EOF) { + break + } + + return nil, fmt.Errorf("decode doc %d: %w", idx, err) + } + + var buf bytes.Buffer + + enc := yaml.NewEncoder(&buf) + enc.SetIndent(2) + if err := enc.Encode(&node); err != nil { + return nil, fmt.Errorf("encode doc %d: %w", idx, err) + } + + _ = enc.Close() + + docs = append(docs, buf.Bytes()) + idx++ + } + + return docs, nil +} diff --git a/csyaml/splityaml_test.go b/csyaml/splityaml_test.go new file mode 100644 index 0000000..4a5816e --- /dev/null +++ b/csyaml/splityaml_test.go @@ -0,0 +1,113 @@ +package csyaml_test + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/csyaml" +) + +func TestSplitDocumentsDecEnc(t *testing.T) { + tests := []struct { + name string + input string + want [][]byte + wantErr string + }{ + { + name: "single mapping", + input: "a: 1\nb: 2\n", + want: [][]byte{[]byte("a: 1\nb: 2\n")}, + }, + { + name: "sequence doc", + input: "- 1\n- 2\n", + want: [][]byte{[]byte("- 1\n- 2\n")}, + }, + { + name: "scalar doc", + input: "\"scalar\"\n", + want: [][]byte{[]byte("\"scalar\"\n")}, + }, + { + name: "multiple documents", + input: `--- +a: 1 +b: 2 +--- +- 1 +- 2 +--- +"scalar" +`, + want: [][]byte{ + []byte("a: 1\nb: 2\n"), + []byte("- 1\n- 2\n"), + []byte("\"scalar\"\n"), + }, + }, + { + name: "empty input", + input: "", + want: [][]byte(nil), + }, + { + name: "invalid YAML", + input: "list: [1, 2,", + want: nil, + wantErr: "decode doc 0: yaml: line 1: did not find expected node content", + }, + { + name: "preserve comments", + input: `# comment 1 +a: 1 +# comment 2 +b: 2 +--- +# comment 3 +- 1 +# comment 4 +- 2 +# comment 5 +--- +# comment 6 +"scalar" +# comment 7 +`, + want: [][]byte{ + []byte("# comment 1\na: 1\n# comment 2\nb: 2\n"), + // not sure how to get rid of the extra new line here. + []byte("# comment 3\n- 1\n# comment 4\n- 2\n\n# comment 5\n"), + []byte("# comment 6\n\"scalar\"\n# comment 7\n"), + }, + }, + { + name: "tricky separator", + input: `--- +text: | + This is a multi-line string. + It includes a line that looks like a document separator: + --- + But it's just part of the string. +--- +key: value +`, + want: [][]byte{ + []byte("text: |\n This is a multi-line string.\n It includes a line that looks like a document separator:\n ---\n But it's just part of the string.\n"), + []byte("key: value\n"), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := strings.NewReader(tc.input) + docs, err := csyaml.SplitDocumentsDecEnc(r) + cstest.RequireErrorContains(t, err, tc.wantErr) + assert.Equal(t, tc.want, docs) + }) + } +} diff --git a/go.mod b/go.mod index 5446ce1..5bd53b6 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,9 @@ module github.com/crowdsecurity/go-cs-lib go 1.23 require ( - github.com/blackfireio/osinfo v1.0.5 + github.com/blackfireio/osinfo v1.1.0 github.com/coreos/go-systemd/v22 v22.5.0 + github.com/goccy/go-yaml v1.18.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index b419592..826f733 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,13 @@ -github.com/blackfireio/osinfo v1.0.5 h1:6hlaWzfcpb87gRmznVf7wSdhysGqLRz9V/xuSdCEXrA= -github.com/blackfireio/osinfo v1.0.5/go.mod h1:Pd987poVNmd5Wsx6PRPw4+w7kLlf9iJxoRKPtPAjOrA= +github.com/blackfireio/osinfo v1.1.0 h1:1LMkMiFL42+Brx7r3MKuf7UTlXBRgebFLJQAfoFafj8= +github.com/blackfireio/osinfo v1.1.0/go.mod h1:Pd987poVNmd5Wsx6PRPw4+w7kLlf9iJxoRKPtPAjOrA= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= +github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= diff --git a/yamlpatch/patcher.go b/yamlpatch/patcher.go index 6f0fd12..31bd335 100644 --- a/yamlpatch/patcher.go +++ b/yamlpatch/patcher.go @@ -17,6 +17,7 @@ type Patcher struct { quiet bool } +// Deprecated: use csyaml.NewPatcher instead. func NewPatcher(filePath string, suffix string) *Patcher { return &Patcher{ BaseFilePath: filePath,