Skip to content

Commit

Permalink
Supports dynamically switch encode and decode processing for a given …
Browse files Browse the repository at this point in the history
…type (#368)
  • Loading branch information
goccy committed Apr 2, 2023
1 parent 1160c31 commit 4052b05
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 3 deletions.
47 changes: 45 additions & 2 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type Decoder struct {
referenceReaders []io.Reader
anchorNodeMap map[string]ast.Node
anchorValueMap map[string]reflect.Value
customUnmarshalerMap map[reflect.Type]func(interface{}, []byte) error
toCommentMap CommentMap
opts []DecodeOption
referenceFiles []string
Expand All @@ -50,6 +51,7 @@ func NewDecoder(r io.Reader, opts ...DecodeOption) *Decoder {
reader: r,
anchorNodeMap: map[string]ast.Node{},
anchorValueMap: map[string]reflect.Value{},
customUnmarshalerMap: map[reflect.Type]func(interface{}, []byte) error{},
opts: opts,
referenceReaders: []io.Reader{},
referenceFiles: []string{},
Expand Down Expand Up @@ -638,8 +640,38 @@ type jsonUnmarshaler interface {
UnmarshalJSON([]byte) error
}

func (d *Decoder) existsTypeInCustomUnmarshalerMap(t reflect.Type) bool {
if _, exists := d.customUnmarshalerMap[t]; exists {
return true
}

globalCustomUnmarshalerMu.Lock()
defer globalCustomUnmarshalerMu.Unlock()
if _, exists := globalCustomUnmarshalerMap[t]; exists {
return true
}
return false
}

func (d *Decoder) unmarshalerFromCustomUnmarshalerMap(t reflect.Type) (func(interface{}, []byte) error, bool) {
if unmarshaler, exists := d.customUnmarshalerMap[t]; exists {
return unmarshaler, exists
}

globalCustomUnmarshalerMu.Lock()
defer globalCustomUnmarshalerMu.Unlock()
if unmarshaler, exists := globalCustomUnmarshalerMap[t]; exists {
return unmarshaler, exists
}
return nil, false
}

func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool {
iface := dst.Addr().Interface()
ptrValue := dst.Addr()
if d.existsTypeInCustomUnmarshalerMap(ptrValue.Type()) {
return true
}
iface := ptrValue.Interface()
switch iface.(type) {
case BytesUnmarshalerContext:
return true
Expand All @@ -662,7 +694,18 @@ func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool {
}

func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, src ast.Node) error {
iface := dst.Addr().Interface()
ptrValue := dst.Addr()
if unmarshaler, exists := d.unmarshalerFromCustomUnmarshalerMap(ptrValue.Type()); exists {
b, err := d.unmarshalableDocument(src)
if err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
if err := unmarshaler(ptrValue.Interface(), b); err != nil {
return errors.Wrapf(err, "failed to UnmarshalYAML")
}
return nil
}
iface := ptrValue.Interface()

if unmarshaler, ok := iface.(BytesUnmarshalerContext); ok {
b, err := d.unmarshalableDocument(src)
Expand Down
48 changes: 48 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1858,6 +1858,54 @@ func TestDecoder_UseJSONUnmarshaler(t *testing.T) {
}
}

func TestDecoder_CustomUnmarshaler(t *testing.T) {
t.Run("override struct type", func(t *testing.T) {
type T struct {
Foo string `yaml:"foo"`
}
src := []byte(`foo: "bar"`)
var v T
if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomUnmarshaler[T](func(dst *T, b []byte) error {
if !bytes.Equal(src, b) {
t.Fatalf("failed to get decode target buffer. expected %q but got %q", src, b)
}
var v T
if err := yaml.Unmarshal(b, &v); err != nil {
return err
}
if v.Foo != "bar" {
t.Fatal("failed to decode")
}
dst.Foo = "bazbaz" // assign another value to target
return nil
})); err != nil {
t.Fatal(err)
}
if v.Foo != "bazbaz" {
t.Fatalf("failed to switch to custom unmarshaler. got: %v", v.Foo)
}
})
t.Run("override bytes type", func(t *testing.T) {
type T struct {
Foo []byte `yaml:"foo"`
}
src := []byte(`foo: "bar"`)
var v T
if err := yaml.UnmarshalWithOptions(src, &v, yaml.CustomUnmarshaler[[]byte](func(dst *[]byte, b []byte) error {
if !bytes.Equal(b, []byte(`"bar"`)) {
t.Fatalf("failed to get target buffer: %q", b)
}
*dst = []byte("bazbaz")
return nil
})); err != nil {
t.Fatal(err)
}
if !bytes.Equal(v.Foo, []byte("bazbaz")) {
t.Fatalf("failed to switch to custom unmarshaler. got: %q", v.Foo)
}
})
}

type unmarshalContext struct {
v int
}
Expand Down
43 changes: 43 additions & 0 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Encoder struct {
useJSONMarshaler bool
anchorCallback func(*ast.AnchorNode, interface{}) error
anchorPtrToNameMap map[uintptr]string
customMarshalerMap map[reflect.Type]func(interface{}) ([]byte, error)
useLiteralStyleIfMultiline bool
commentMap map[*Path][]*Comment
written bool
Expand All @@ -56,6 +57,7 @@ func NewEncoder(w io.Writer, opts ...EncodeOption) *Encoder {
opts: opts,
indent: DefaultIndentSpaces,
anchorPtrToNameMap: map[uintptr]string{},
customMarshalerMap: map[reflect.Type]func(interface{}) ([]byte, error){},
line: 1,
column: 1,
offset: 0,
Expand Down Expand Up @@ -273,10 +275,39 @@ type jsonMarshaler interface {
MarshalJSON() ([]byte, error)
}

func (e *Encoder) existsTypeInCustomMarshalerMap(t reflect.Type) bool {
if _, exists := e.customMarshalerMap[t]; exists {
return true
}

globalCustomMarshalerMu.Lock()
defer globalCustomMarshalerMu.Unlock()
if _, exists := globalCustomMarshalerMap[t]; exists {
return true
}
return false
}

func (e *Encoder) marshalerFromCustomMarshalerMap(t reflect.Type) (func(interface{}) ([]byte, error), bool) {
if marshaler, exists := e.customMarshalerMap[t]; exists {
return marshaler, exists
}

globalCustomMarshalerMu.Lock()
defer globalCustomMarshalerMu.Unlock()
if marshaler, exists := globalCustomMarshalerMap[t]; exists {
return marshaler, exists
}
return nil, false
}

func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool {
if !v.CanInterface() {
return false
}
if e.existsTypeInCustomMarshalerMap(v.Type()) {
return true
}
iface := v.Interface()
switch iface.(type) {
case BytesMarshalerContext:
Expand All @@ -302,6 +333,18 @@ func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool {
func (e *Encoder) encodeByMarshaler(ctx context.Context, v reflect.Value, column int) (ast.Node, error) {
iface := v.Interface()

if marshaler, exists := e.marshalerFromCustomMarshalerMap(v.Type()); exists {
doc, err := marshaler(iface)
if err != nil {
return nil, errors.Wrapf(err, "failed to MarshalYAML")
}
node, err := e.encodeDocument(doc)
if err != nil {
return nil, errors.Wrapf(err, "failed to encode document")
}
return node, nil
}

if marshaler, ok := iface.(BytesMarshalerContext); ok {
doc, err := marshaler.MarshalYAML(ctx)
if err != nil {
Expand Down
37 changes: 36 additions & 1 deletion encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"bytes"
"context"
"fmt"
"github.com/goccy/go-yaml/parser"
"math"
"reflect"
"strconv"
"testing"
"time"

"github.com/goccy/go-yaml/parser"

"github.com/goccy/go-yaml"
"github.com/goccy/go-yaml/ast"
)
Expand Down Expand Up @@ -1177,6 +1178,40 @@ a:
}
}

func TestEncoder_CustomMarshaler(t *testing.T) {
t.Run("override struct type", func(t *testing.T) {
type T struct {
Foo string `yaml:"foo"`
}
b, err := yaml.MarshalWithOptions(&T{Foo: "bar"}, yaml.CustomMarshaler[T](func(v T) ([]byte, error) {
return []byte(`"override"`), nil
}))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(b, []byte("\"override\"\n")) {
t.Fatalf("failed to switch to custom marshaler. got: %q", b)
}
})
t.Run("override bytes type", func(t *testing.T) {
type T struct {
Foo []byte `yaml:"foo"`
}
b, err := yaml.MarshalWithOptions(&T{Foo: []byte("bar")}, yaml.CustomMarshaler[[]byte](func(v []byte) ([]byte, error) {
if !bytes.Equal(v, []byte("bar")) {
t.Fatalf("failed to get src buffer: %q", v)
}
return []byte(`override`), nil
}))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(b, []byte("foo: override\n")) {
t.Fatalf("failed to switch to custom marshaler. got: %q", b)
}
})
}

func TestEncoder_MultipleDocuments(t *testing.T) {
var buf bytes.Buffer
enc := yaml.NewEncoder(&buf)
Expand Down
30 changes: 30 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package yaml

import (
"io"
"reflect"

"github.com/goccy/go-yaml/ast"
)
Expand Down Expand Up @@ -94,6 +95,20 @@ func UseJSONUnmarshaler() DecodeOption {
}
}

// CustomUnmarshaler overrides any decoding process for the type specified in generics.
//
// NOTE: If RegisterCustomUnmarshaler and CustomUnmarshaler of DecodeOption are specified for the same type,
// the CustomUnmarshaler specified in DecodeOption takes precedence.
func CustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) DecodeOption {
return func(d *Decoder) error {
var typ *T
d.customUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, b []byte) error {
return unmarshaler(v.(*T), b)
}
return nil
}
}

// EncodeOption functional option type for Encoder
type EncodeOption func(e *Encoder) error

Expand Down Expand Up @@ -165,6 +180,21 @@ func UseJSONMarshaler() EncodeOption {
}
}

// CustomMarshaler overrides any encoding process for the type specified in generics.
//
// NOTE: If type T implements MarshalYAML for pointer receiver, the type specified in CustomMarshaler must be *T.
// If RegisterCustomMarshaler and CustomMarshaler of EncodeOption are specified for the same type,
// the CustomMarshaler specified in EncodeOption takes precedence.
func CustomMarshaler[T any](marshaler func(T) ([]byte, error)) EncodeOption {
return func(e *Encoder) error {
var typ T
e.customMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) ([]byte, error) {
return marshaler(v.(T))
}
return nil
}
}

// CommentPosition type of the position for comment.
type CommentPosition int

Expand Down
40 changes: 40 additions & 0 deletions yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"context"
"io"
"reflect"
"sync"

"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/internal/errors"
Expand Down Expand Up @@ -248,3 +250,41 @@ func JSONToYAML(bytes []byte) ([]byte, error) {
}
return out, nil
}

var (
globalCustomMarshalerMu sync.Mutex
globalCustomUnmarshalerMu sync.Mutex
globalCustomMarshalerMap = map[reflect.Type]func(interface{}) ([]byte, error){}
globalCustomUnmarshalerMap = map[reflect.Type]func(interface{}, []byte) error{}
)

// RegisterCustomMarshaler overrides any encoding process for the type specified in generics.
// If you want to switch the behavior for each encoder, use `CustomMarshaler` defined as EncodeOption.
//
// NOTE: If type T implements MarshalYAML for pointer receiver, the type specified in RegisterCustomMarshaler must be *T.
// If RegisterCustomMarshaler and CustomMarshaler of EncodeOption are specified for the same type,
// the CustomMarshaler specified in EncodeOption takes precedence.
func RegisterCustomMarshaler[T any](marshaler func(T) ([]byte, error)) {
globalCustomMarshalerMu.Lock()
defer globalCustomMarshalerMu.Unlock()

var typ T
globalCustomMarshalerMap[reflect.TypeOf(typ)] = func(v interface{}) ([]byte, error) {
return marshaler(v.(T))
}
}

// RegisterCustomUnmarshaler overrides any decoding process for the type specified in generics.
// If you want to switch the behavior for each decoder, use `CustomUnmarshaler` defined as DecodeOption.
//
// NOTE: If RegisterCustomUnmarshaler and CustomUnmarshaler of DecodeOption are specified for the same type,
// the CustomUnmarshaler specified in DecodeOption takes precedence.
func RegisterCustomUnmarshaler[T any](unmarshaler func(*T, []byte) error) {
globalCustomUnmarshalerMu.Lock()
defer globalCustomUnmarshalerMu.Unlock()

var typ *T
globalCustomUnmarshalerMap[reflect.TypeOf(typ)] = func(v interface{}, b []byte) error {
return unmarshaler(v.(*T), b)
}
}
Loading

0 comments on commit 4052b05

Please sign in to comment.