Skip to content

Commit

Permalink
Support for custom tags
Browse files Browse the repository at this point in the history
  • Loading branch information
sanathkr committed Aug 19, 2017
1 parent eb3733d commit ed9d249
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
.idea
32 changes: 28 additions & 4 deletions decode.go
Expand Up @@ -162,6 +162,7 @@ func (p *parser) scalar() *node {

func (p *parser) sequence() *node {
n := p.node(sequenceNode)
n.tag = string(p.event.tag)
p.anchor(n, p.event.anchor)
p.skip()
for p.event.typ != yaml_SEQUENCE_END_EVENT {
Expand All @@ -173,6 +174,7 @@ func (p *parser) sequence() *node {

func (p *parser) mapping() *node {
n := p.node(mappingNode)
n.tag = string(p.event.tag)
p.anchor(n, p.event.anchor)
p.skip()
for p.event.typ != yaml_MAPPING_END_EVENT {
Expand All @@ -194,10 +196,11 @@ type decoder struct {
}

var (
mapItemType = reflect.TypeOf(MapItem{})
durationType = reflect.TypeOf(time.Duration(0))
defaultMapType = reflect.TypeOf(map[interface{}]interface{}{})
ifaceType = defaultMapType.Elem()
mapItemType = reflect.TypeOf(MapItem{})
durationType = reflect.TypeOf(time.Duration(0))
defaultMapType = reflect.TypeOf(map[interface{}]interface{}{})
ifaceType = defaultMapType.Elem()
tagUnmarshalers = map[string]TagUnmarshaler{}
)

func newDecoder(strict bool) *decoder {
Expand All @@ -206,6 +209,16 @@ func newDecoder(strict bool) *decoder {
return d
}

func registerCustomTagUnmarshaler(tag string, unmarshaler TagUnmarshaler) {
tagUnmarshalers[tag] = unmarshaler
}

func unregisterCustomTagUnmarshaler(tag string) {
if _, ok := tagUnmarshalers[tag]; ok {
delete(tagUnmarshalers, tag)
}
}

func (d *decoder) terror(n *node, tag string, out reflect.Value) {
if n.tag != "" {
tag = n.tag
Expand Down Expand Up @@ -295,6 +308,17 @@ func (d *decoder) unmarshal(n *node, out reflect.Value) (good bool) {
default:
panic("internal error: unknown node kind: " + strconv.Itoa(n.kind))
}

// If the node has a tag, and a custom tag unmarshaler is registered,
// then call it to unmarshal rest of the tree
if good && len(n.tag) > 0 {
if unmarshaller, found := tagUnmarshalers[n.tag]; found {
tagSuffix := n.tag[1:] // Remove starting ! from tag
newOutput := unmarshaller.UnmarshalYAMLTag(tagSuffix, out)
out.Set(newOutput)
}
}

return good
}

Expand Down
80 changes: 78 additions & 2 deletions decode_test.go
Expand Up @@ -2,13 +2,14 @@ package yaml_test

import (
"errors"
. "gopkg.in/check.v1"
"gopkg.in/yaml.v2"
"math"
"net"
"reflect"
"strings"
"time"

. "gopkg.in/check.v1"
"gopkg.in/yaml.v2"
)

var unmarshalIntTest = 123
Expand Down Expand Up @@ -987,6 +988,81 @@ func (s *S) TestUnmarshalStrict(c *C) {
c.Check(err, ErrorMatches, "yaml: unmarshal errors:\n line 1: field c not found in struct struct { A int; B int }")
}

type tagUnmarshalerType struct {
}

func (t *tagUnmarshalerType) UnmarshalYAMLTag(tag string, fieldValue reflect.Value) reflect.Value {

output := reflect.ValueOf(make(map[string]interface{}))
key := reflect.ValueOf(tag)

output.SetMapIndex(key, fieldValue)

return output
}

func (s *S) TestTagUnmarshalNonStringPrimitiveValue(c *C) {
// All values after the "Tag" will be converted to string
a := `some: !TheTag 1`
var out map[string]interface{}

un := &tagUnmarshalerType{}
yaml.RegisterTagUnmarshaler("!TheTag", un)
yaml.Unmarshal([]byte(a), &out)
c.Assert(out, DeepEquals, map[string]interface{}{"some": map[string]interface{}{"TheTag": "1"}})
yaml.UnRegisterTagUnmarshaler("!TheTag")
}

func (s *S) TestTagUnmarshalToMap(c *C) {
a := `some: !TheTag hello`
var out map[string]interface{}

un := &tagUnmarshalerType{}
yaml.RegisterTagUnmarshaler("!TheTag", un)
yaml.Unmarshal([]byte(a), &out)
c.Assert(out, DeepEquals, map[string]interface{}{"some": map[string]interface{}{"TheTag": "hello"}})
yaml.UnRegisterTagUnmarshaler("!TheTag")
}

func (s *S) TestTagUnmarshalWithArrayValue(c *C) {

a := "key:\n some: !TheTag ['a', 'b']"
var out map[string]map[string]interface{}

un := &tagUnmarshalerType{}
yaml.RegisterTagUnmarshaler("!TheTag", un)
yaml.Unmarshal([]byte(a), &out)
c.Assert(out, DeepEquals, map[string]map[string]interface{}{"key": {"some": map[string]interface{}{"TheTag": []interface{}{"a", "b"}}}})
yaml.UnRegisterTagUnmarshaler("!TheTag")
}

func (s *S) TestTagUnmarshalWithMapValue(c *C) {

a := "some: !Tag {'a': 'b'}"
var out map[string]map[string]interface{}

un := &tagUnmarshalerType{}
yaml.RegisterTagUnmarshaler("!Tag", un)
yaml.Unmarshal([]byte(a), &out)
c.Assert(out, DeepEquals, map[string]map[string]interface{}{"some": {"Tag": map[string]interface{}{"a": "b"}}})
yaml.UnRegisterTagUnmarshaler("!Tag")
}

func (s *S) TestTagUnmarshalWithNestedTags(c *C) {

a := "some: !Tag [!OtherTag 'val1', 'val2']"
var out map[string]interface{}

un := &tagUnmarshalerType{}
yaml.RegisterTagUnmarshaler("!Tag", un)
yaml.RegisterTagUnmarshaler("!OtherTag", un)
yaml.Unmarshal([]byte(a), &out)
c.Assert(out, DeepEquals, map[string]interface{}{"some": map[string]interface{}{"Tag": []interface{}{ map[string]interface{} {"OtherTag": "val1"} , "val2"}}})
yaml.UnRegisterTagUnmarshaler("!Tag")
yaml.UnRegisterTagUnmarshaler("!OtherTag")
}


//var data []byte
//func init() {
// var err error
Expand Down
17 changes: 17 additions & 0 deletions yaml.go
Expand Up @@ -32,6 +32,15 @@ type Unmarshaler interface {
UnmarshalYAML(unmarshal func(interface{}) error) error
}

type TagUnmarshallerDecoder struct {
d *decoder
}

// The Tag Unmarshaler interface
type TagUnmarshaler interface {
UnmarshalYAMLTag(tag string, out reflect.Value) reflect.Value
}

// The Marshaler interface may be implemented by types to customize their
// behavior when being marshaled into a YAML document. The returned value
// is marshaled in place of the original value implementing Marshaler.
Expand Down Expand Up @@ -156,6 +165,14 @@ func Marshal(in interface{}) (out []byte, err error) {
return
}

func RegisterTagUnmarshaler(tag string, unmarshaler TagUnmarshaler) {
registerCustomTagUnmarshaler(tag, unmarshaler)
}

func UnRegisterTagUnmarshaler(tag string) {
unregisterCustomTagUnmarshaler(tag)
}

func handleErr(err *error) {
if v := recover(); v != nil {
if e, ok := v.(yamlError); ok {
Expand Down

0 comments on commit ed9d249

Please sign in to comment.