Skip to content

Commit

Permalink
fix: do not report types implementing (Un)Marshaler (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmzane committed Oct 8, 2023
1 parent 11f0e6c commit 1105e1a
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 26 deletions.
154 changes: 132 additions & 22 deletions builtins.go
Expand Up @@ -3,34 +3,144 @@ package musttag
// builtins is a set of functions supported out of the box.
var builtins = []Func{
// https://pkg.go.dev/encoding/json
{Name: "encoding/json.Marshal", Tag: "json", ArgPos: 0},
{Name: "encoding/json.MarshalIndent", Tag: "json", ArgPos: 0},
{Name: "encoding/json.Unmarshal", Tag: "json", ArgPos: 1},
{Name: "(*encoding/json.Encoder).Encode", Tag: "json", ArgPos: 0},
{Name: "(*encoding/json.Decoder).Decode", Tag: "json", ArgPos: 0},
{
Name: "encoding/json.Marshal",
Tag: "json",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "encoding/json.MarshalIndent",
Tag: "json",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "encoding/json.Unmarshal",
Tag: "json",
ArgPos: 1,
ifaceWhitelist: []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "(*encoding/json.Encoder).Encode",
Tag: "json",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/json.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "(*encoding/json.Decoder).Decode",
Tag: "json",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/json.Unmarshaler", "encoding.TextUnmarshaler"},
},

// https://pkg.go.dev/encoding/xml
{Name: "encoding/xml.Marshal", Tag: "xml", ArgPos: 0},
{Name: "encoding/xml.MarshalIndent", Tag: "xml", ArgPos: 0},
{Name: "encoding/xml.Unmarshal", Tag: "xml", ArgPos: 1},
{Name: "(*encoding/xml.Encoder).Encode", Tag: "xml", ArgPos: 0},
{Name: "(*encoding/xml.Decoder).Decode", Tag: "xml", ArgPos: 0},
{Name: "(*encoding/xml.Encoder).EncodeElement", Tag: "xml", ArgPos: 0},
{Name: "(*encoding/xml.Decoder).DecodeElement", Tag: "xml", ArgPos: 0},
{
Name: "encoding/xml.Marshal",
Tag: "xml",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "encoding/xml.MarshalIndent",
Tag: "xml",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "encoding/xml.Unmarshal",
Tag: "xml",
ArgPos: 1,
ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "(*encoding/xml.Encoder).Encode",
Tag: "xml",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "(*encoding/xml.Decoder).Decode",
Tag: "xml",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "(*encoding/xml.Encoder).EncodeElement",
Tag: "xml",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Marshaler", "encoding.TextMarshaler"},
},
{
Name: "(*encoding/xml.Decoder).DecodeElement",
Tag: "xml",
ArgPos: 0,
ifaceWhitelist: []string{"encoding/xml.Unmarshaler", "encoding.TextUnmarshaler"},
},

// https://github.com/go-yaml/yaml
{Name: "gopkg.in/yaml.v3.Marshal", Tag: "yaml", ArgPos: 0},
{Name: "gopkg.in/yaml.v3.Unmarshal", Tag: "yaml", ArgPos: 1},
{Name: "(*gopkg.in/yaml.v3.Encoder).Encode", Tag: "yaml", ArgPos: 0},
{Name: "(*gopkg.in/yaml.v3.Decoder).Decode", Tag: "yaml", ArgPos: 0},
{
Name: "gopkg.in/yaml.v3.Marshal",
Tag: "yaml",
ArgPos: 0,
ifaceWhitelist: []string{"gopkg.in/yaml.v3.Marshaler"},
},
{
Name: "gopkg.in/yaml.v3.Unmarshal",
Tag: "yaml",
ArgPos: 1,
ifaceWhitelist: []string{"gopkg.in/yaml.v3.Unmarshaler"},
},
{
Name: "(*gopkg.in/yaml.v3.Encoder).Encode",
Tag: "yaml",
ArgPos: 0,
ifaceWhitelist: []string{"gopkg.in/yaml.v3.Marshaler"},
},
{
Name: "(*gopkg.in/yaml.v3.Decoder).Decode",
Tag: "yaml",
ArgPos: 0,
ifaceWhitelist: []string{"gopkg.in/yaml.v3.Unmarshaler"},
},

// https://github.com/BurntSushi/toml
{Name: "github.com/BurntSushi/toml.Unmarshal", Tag: "toml", ArgPos: 1},
{Name: "github.com/BurntSushi/toml.Decode", Tag: "toml", ArgPos: 1},
{Name: "github.com/BurntSushi/toml.DecodeFS", Tag: "toml", ArgPos: 2},
{Name: "github.com/BurntSushi/toml.DecodeFile", Tag: "toml", ArgPos: 1},
{Name: "(*github.com/BurntSushi/toml.Encoder).Encode", Tag: "toml", ArgPos: 0},
{Name: "(*github.com/BurntSushi/toml.Decoder).Decode", Tag: "toml", ArgPos: 0},
{
Name: "github.com/BurntSushi/toml.Unmarshal",
Tag: "toml",
ArgPos: 1,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "github.com/BurntSushi/toml.Decode",
Tag: "toml",
ArgPos: 1,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "github.com/BurntSushi/toml.DecodeFS",
Tag: "toml",
ArgPos: 2,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "github.com/BurntSushi/toml.DecodeFile",
Tag: "toml",
ArgPos: 1,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},
{
Name: "(*github.com/BurntSushi/toml.Encoder).Encode",
Tag: "toml",
ArgPos: 0,
ifaceWhitelist: []string{"encoding.TextMarshaler"},
},
{
Name: "(*github.com/BurntSushi/toml.Decoder).Decode",
Tag: "toml",
ArgPos: 0,
ifaceWhitelist: []string{"github.com/BurntSushi/toml.Unmarshaler", "encoding.TextUnmarshaler"},
},

// https://github.com/mitchellh/mapstructure
{Name: "github.com/mitchellh/mapstructure.Decode", Tag: "mapstructure", ArgPos: 1},
Expand Down
67 changes: 63 additions & 4 deletions musttag.go
Expand Up @@ -24,6 +24,10 @@ type Func struct {
Name string // Name is the full name of the function, including the package.
Tag string // Tag is the struct tag whose presence should be ensured.
ArgPos int // ArgPos is the position of the argument to check.

// a list of interface names (including the package);
// if at least one is implemented by the argument, no check is performed.
ifaceWhitelist []string
}

func (fn Func) shortName() string {
Expand Down Expand Up @@ -93,7 +97,7 @@ var report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Posi
pass.Reportf(st.Pos, format, st.Name, fn.Tag, fn.shortName(), fnPos)
}

var cleanFullName = regexp.MustCompile(`([^*/(]+/vendor/)`)
var trimVendor = regexp.MustCompile(`([^*/(]+/vendor/)`)

// run starts the analysis.
func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, error) {
Expand All @@ -117,7 +121,7 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, er
return // not a static call.
}

name := cleanFullName.ReplaceAllString(callee.FullName(), "")
name := trimVendor.ReplaceAllString(callee.FullName(), "")
fn, ok := funcs[name]
if !ok {
return // the function is not supported.
Expand All @@ -144,13 +148,21 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, er
initialPos = arg.Pos()
}

typ := pass.TypesInfo.TypeOf(arg)
if typ == nil {
return // no type info found.
}

if implementsInterface(typ, fn.ifaceWhitelist, pass.Pkg.Imports()) {
return // the type implements a Marshaler interface, nothing to check; see issue #64.
}

checker := checker{
mainModule: mainModule,
seenTypes: make(map[string]struct{}),
}

t := pass.TypesInfo.TypeOf(arg)
st, ok := checker.parseStructType(t, initialPos)
st, ok := checker.parseStructType(typ, initialPos)
if !ok {
return // not a struct argument.
}
Expand Down Expand Up @@ -257,3 +269,50 @@ func (c *checker) checkStructType(st *structType, tag string) (*structType, bool

return nil, true
}

func implementsInterface(typ types.Type, ifaces []string, imports []*types.Package) bool {
findScope := func(pkgName string) (*types.Scope, bool) {
// fast path: check direct imports (e.g. looking for "encoding/json.Marshaler").
for _, direct := range imports {
if pkgName == trimVendor.ReplaceAllString(direct.Path(), "") {
return direct.Scope(), true
}
}
// slow path: check indirect imports (e.g. looking for "encoding.TextMarshaler").
for _, direct := range imports {
for _, indirect := range direct.Imports() {
if pkgName == trimVendor.ReplaceAllString(indirect.Path(), "") {
return indirect.Scope(), true
}
}
}
return nil, false
}

for _, ifacePath := range ifaces {
// "encoding/json.Marshaler" -> "encoding/json" + "Marshaler"
idx := strings.LastIndex(ifacePath, ".")
if idx == -1 {
continue
}
pkgName, ifaceName := ifacePath[:idx], ifacePath[idx+1:]

scope, ok := findScope(pkgName)
if !ok {
continue
}
obj := scope.Lookup(ifaceName)
if obj == nil {
continue
}
iface, ok := obj.Type().Underlying().(*types.Interface)
if !ok {
continue
}
if types.Implements(typ, iface) {
return true
}
}

return false
}
70 changes: 70 additions & 0 deletions testdata/src/builtins/builtins.go
Expand Up @@ -76,13 +76,44 @@ type User struct { /* want
Email string `json:"email" xml:"email" yaml:"email" toml:"email" mapstructure:"email" db:"email" custom:"email"`
}

// TODO: Unmarshaler should be implemented using pointer semantics.

type TextMarshaler struct{ NoTag string }

func (TextMarshaler) MarshalText() ([]byte, error) { return nil, nil }
func (TextMarshaler) UnmarshalText([]byte) error { return nil }

type Marshaler struct{ NoTag string }

func (Marshaler) MarshalJSON() ([]byte, error) { return nil, nil }
func (Marshaler) UnmarshalJSON([]byte) error { return nil }
func (Marshaler) MarshalXML(e *xml.Encoder, start xml.StartElement) error { return nil }
func (Marshaler) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { return nil }
func (Marshaler) MarshalYAML() (any, error) { return nil, nil }
func (Marshaler) UnmarshalYAML(*yaml.Node) error { return nil }
func (Marshaler) UnmarshalTOML(any) error { return nil }

func testJSON() {
var user User
json.Marshal(user)
json.MarshalIndent(user, "", "")
json.Unmarshal(nil, &user)
json.NewEncoder(nil).Encode(user)
json.NewDecoder(nil).Decode(&user)

var m Marshaler
json.Marshal(m)
json.MarshalIndent(m, "", "")
json.Unmarshal(nil, &m)
json.NewEncoder(nil).Encode(m)
json.NewDecoder(nil).Decode(&m)

var tm TextMarshaler
json.Marshal(tm)
json.MarshalIndent(tm, "", "")
json.Unmarshal(nil, &tm)
json.NewEncoder(nil).Encode(tm)
json.NewDecoder(nil).Decode(&tm)
}

func testXML() {
Expand All @@ -94,6 +125,24 @@ func testXML() {
xml.NewDecoder(nil).Decode(&user)
xml.NewEncoder(nil).EncodeElement(user, xml.StartElement{})
xml.NewDecoder(nil).DecodeElement(&user, &xml.StartElement{})

var m Marshaler
xml.Marshal(m)
xml.MarshalIndent(m, "", "")
xml.Unmarshal(nil, &m)
xml.NewEncoder(nil).Encode(m)
xml.NewDecoder(nil).Decode(&m)
xml.NewEncoder(nil).EncodeElement(m, xml.StartElement{})
xml.NewDecoder(nil).DecodeElement(&m, &xml.StartElement{})

var tm TextMarshaler
xml.Marshal(tm)
xml.MarshalIndent(tm, "", "")
xml.Unmarshal(nil, &tm)
xml.NewEncoder(nil).Encode(tm)
xml.NewDecoder(nil).Decode(&tm)
xml.NewEncoder(nil).EncodeElement(tm, xml.StartElement{})
xml.NewDecoder(nil).DecodeElement(&tm, &xml.StartElement{})
}

func testYAML() {
Expand All @@ -102,6 +151,12 @@ func testYAML() {
yaml.Unmarshal(nil, &user)
yaml.NewEncoder(nil).Encode(user)
yaml.NewDecoder(nil).Decode(&user)

var m Marshaler
yaml.Marshal(m)
yaml.Unmarshal(nil, &m)
yaml.NewEncoder(nil).Encode(m)
yaml.NewDecoder(nil).Decode(&m)
}

func testTOML() {
Expand All @@ -112,6 +167,21 @@ func testTOML() {
toml.DecodeFile("", &user)
toml.NewEncoder(nil).Encode(user)
toml.NewDecoder(nil).Decode(&user)

var m Marshaler
toml.Unmarshal(nil, &m)
toml.Decode("", &m)
toml.DecodeFS(nil, "", &m)
toml.DecodeFile("", &m)
toml.NewDecoder(nil).Decode(&m)

var tm TextMarshaler
toml.Unmarshal(nil, &tm)
toml.Decode("", &tm)
toml.DecodeFS(nil, "", &tm)
toml.DecodeFile("", &tm)
toml.NewEncoder(nil).Encode(tm)
toml.NewDecoder(nil).Decode(&tm)
}

func testMapstructure() {
Expand Down

0 comments on commit 1105e1a

Please sign in to comment.