From cb3c985608ba7bedb85baf2e7e3579067c301978 Mon Sep 17 00:00:00 2001 From: Max Huang-Hobbs Date: Thu, 4 Apr 2024 21:11:33 +0000 Subject: [PATCH 1/3] handle []T where T implements json.Unmarshal --- musttag.go | 34 +++++++++++++++++++++++----------- testdata/src/tests/builtins.go | 11 +++++++++++ 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/musttag.go b/musttag.go index c147ec4..614b226 100644 --- a/musttag.go +++ b/musttag.go @@ -149,31 +149,43 @@ func (c *checker) checkType(typ types.Type, tag string) bool { } c.seenTypes[typ.String()] = struct{}{} - if implementsInterface(typ, c.ifaceWhitelist, c.imports) { - return true // the type implements a Marshaler interface; see issue #64. - } - - styp, ok := c.parseStruct(typ) - if !ok { + styp, shouldCheckStruct := c.unwrapStructType(typ) + if !shouldCheckStruct { return true // not a struct. } return c.checkStruct(styp, tag) } -func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) { +// recursively unpack a type to the next type that needs checking, +// until we get to some underlying struct for checking +// +// SomeStruct -> struct{SomeStructField: ... } +// []*SomeStruct -> struct{SomeStructField: ... } +// ... +// +// exits early if it hits a type that implements a whitelisted interface +func (c *checker) unwrapStructType(typ types.Type) (*types.Struct, bool) { + fmt.Println("unwrapStructType", typ, c.ifaceWhitelist, c.imports) + + if implementsInterface(typ, c.ifaceWhitelist, c.imports) { + fmt.Println(" - unwrapStructType exit early implements!") + return nil, false // the type implements a Marshaler interface; see issue #64. + } + switch typ := typ.(type) { case *types.Pointer: - return c.parseStruct(typ.Elem()) + return c.unwrapStructType(typ.Elem()) case *types.Array: - return c.parseStruct(typ.Elem()) + return c.unwrapStructType(typ.Elem()) case *types.Slice: - return c.parseStruct(typ.Elem()) + fmt.Println("elem!:", typ.Elem()) + return c.unwrapStructType(typ.Elem()) case *types.Map: - return c.parseStruct(typ.Elem()) + return c.unwrapStructType(typ.Elem()) case *types.Named: // a struct of the named type. pkg := typ.Obj().Pkg() diff --git a/testdata/src/tests/builtins.go b/testdata/src/tests/builtins.go index d576da4..9d51805 100644 --- a/testdata/src/tests/builtins.go +++ b/testdata/src/tests/builtins.go @@ -51,6 +51,17 @@ func testJSON() { json.NewDecoder(nil).Decode(&tm) } +func testJSONIndirectSlice() { + type WithMarshallableSlice struct { + List []Marshaler `json:"marshallable"` + } + var withMarshallableSlice WithMarshallableSlice + + json.Marshal(withMarshallableSlice) + json.MarshalIndent(withMarshallableSlice, "", "") + json.NewEncoder(nil).Encode(withMarshallableSlice) +} + func testXML() { var st Struct xml.Marshal(st) // want "the given struct should be annotated with the `xml` tag" From 943ef0d605f440970a66a035db904d70312c3125 Mon Sep 17 00:00:00 2001 From: Max Huang-Hobbs Date: Fri, 5 Apr 2024 11:44:33 +0000 Subject: [PATCH 2/3] address comments --- musttag.go | 21 +++++++++------------ testdata/src/tests/builtins.go | 11 ----------- testdata/src/tests/tests.go | 11 +++++++++++ 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/musttag.go b/musttag.go index 614b226..7a794c1 100644 --- a/musttag.go +++ b/musttag.go @@ -149,43 +149,40 @@ func (c *checker) checkType(typ types.Type, tag string) bool { } c.seenTypes[typ.String()] = struct{}{} - styp, shouldCheckStruct := c.unwrapStructType(typ) - if !shouldCheckStruct { + styp, ok := c.parseStruct(typ) + if !ok { return true // not a struct. } return c.checkStruct(styp, tag) } -// recursively unpack a type to the next type that needs checking, -// until we get to some underlying struct for checking +// recursively unwrap a type until we get to an underlying +// raw struct type that should have its fields checked // // SomeStruct -> struct{SomeStructField: ... } // []*SomeStruct -> struct{SomeStructField: ... } // ... // // exits early if it hits a type that implements a whitelisted interface -func (c *checker) unwrapStructType(typ types.Type) (*types.Struct, bool) { - fmt.Println("unwrapStructType", typ, c.ifaceWhitelist, c.imports) +func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) { if implementsInterface(typ, c.ifaceWhitelist, c.imports) { - fmt.Println(" - unwrapStructType exit early implements!") return nil, false // the type implements a Marshaler interface; see issue #64. } switch typ := typ.(type) { case *types.Pointer: - return c.unwrapStructType(typ.Elem()) + return c.parseStruct(typ.Elem()) case *types.Array: - return c.unwrapStructType(typ.Elem()) + return c.parseStruct(typ.Elem()) case *types.Slice: - fmt.Println("elem!:", typ.Elem()) - return c.unwrapStructType(typ.Elem()) + return c.parseStruct(typ.Elem()) case *types.Map: - return c.unwrapStructType(typ.Elem()) + return c.parseStruct(typ.Elem()) case *types.Named: // a struct of the named type. pkg := typ.Obj().Pkg() diff --git a/testdata/src/tests/builtins.go b/testdata/src/tests/builtins.go index 9d51805..d576da4 100644 --- a/testdata/src/tests/builtins.go +++ b/testdata/src/tests/builtins.go @@ -51,17 +51,6 @@ func testJSON() { json.NewDecoder(nil).Decode(&tm) } -func testJSONIndirectSlice() { - type WithMarshallableSlice struct { - List []Marshaler `json:"marshallable"` - } - var withMarshallableSlice WithMarshallableSlice - - json.Marshal(withMarshallableSlice) - json.MarshalIndent(withMarshallableSlice, "", "") - json.NewEncoder(nil).Encode(withMarshallableSlice) -} - func testXML() { var st Struct xml.Marshal(st) // want "the given struct should be annotated with the `xml` tag" diff --git a/testdata/src/tests/tests.go b/testdata/src/tests/tests.go index 44266a8..7bcc489 100644 --- a/testdata/src/tests/tests.go +++ b/testdata/src/tests/tests.go @@ -164,3 +164,14 @@ func ignoredNestedType() { json.Marshal(Foo{}) // no error json.Marshal(&Foo{}) // no error } + +func interfaceSliceType() { + type WithMarshallableSlice struct { + List []Marshaler `json:"marshallable"` + } + var withMarshallableSlice WithMarshallableSlice + + json.Marshal(withMarshallableSlice) + json.MarshalIndent(withMarshallableSlice, "", "") + json.NewEncoder(nil).Encode(withMarshallableSlice) +} From 47c706f0410d92133fe4fc7539235795df9b45a0 Mon Sep 17 00:00:00 2001 From: Tom <73077675+tmzane@users.noreply.github.com> Date: Fri, 5 Apr 2024 18:10:29 +0500 Subject: [PATCH 3/3] fix linter complain --- musttag.go | 1 - 1 file changed, 1 deletion(-) diff --git a/musttag.go b/musttag.go index 7a794c1..c4f4f7d 100644 --- a/musttag.go +++ b/musttag.go @@ -166,7 +166,6 @@ func (c *checker) checkType(typ types.Type, tag string) bool { // // exits early if it hits a type that implements a whitelisted interface func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) { - if implementsInterface(typ, c.ifaceWhitelist, c.imports) { return nil, false // the type implements a Marshaler interface; see issue #64. }