From 1bc8c3c28bb9f75e7eff8457440fe6fd913d33e5 Mon Sep 17 00:00:00 2001 From: Oleg Butuzov Date: Sat, 6 May 2023 16:15:17 +0300 Subject: [PATCH] refactor: conversation type checks - removed redundant `types` - simplified check flow --- internal/checker/checker.go | 43 +++++++++++++++++++++++------------ internal/checker/types.go | 30 ------------------------ internal/checker/violation.go | 6 ++--- 3 files changed, 31 insertions(+), 48 deletions(-) delete mode 100644 internal/checker/types.go diff --git a/internal/checker/checker.go b/internal/checker/checker.go index 49209ae..3b41dd9 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -34,8 +34,15 @@ func (c *Checker) Check(e *ast.CallExpr) *Violation { // Regular calls (`*ast.SelectorExpr`) like strings.HasPrefix or re.Match are // handled by this check case *ast.SelectorExpr: + x, ok := expr.X.(*ast.Ident) + // TODO: add check for the ast.ParenExpr in e.Fun so we can target + // the constructions like this + // + // (&maphash.Hash{}).Write([]byte("foobar")) + // + if !ok { return nil // can't be mached, so can't be checked. } @@ -63,7 +70,6 @@ func (c *Checker) Check(e *ast.CallExpr) *Violation { return v.WithAltArgs(argsFixed) } } - } return nil } @@ -71,25 +77,24 @@ func (c *Checker) Check(e *ast.CallExpr) *Violation { func (c *Checker) handleViolation(v *Violation, ce *ast.CallExpr) (map[int]ast.Expr, bool) { m := map[int]ast.Expr{} + // We going to check each of elements we mark for checking, in order to find, + // a call that violates our rules. for _, i := range v.Args { if i >= len(ce.Args) { continue } call, ok := ce.Args[i].(*ast.CallExpr) - if !ok { + if !ok || !c.isConverterCall(call) || len(call.Args) == 0 { continue } - if t := c.Type(call); t.String() != "string" && t.String() != "[]byte" { - continue - } - - if string(v.Targets()) != c.types.TypeOf(call.Args[0]).String() { + // checking whats argument + if v.Targets() != c.Type(call.Args[0]).String() { m[i] = call.Args[0] } - } + return m, len(m) == len(v.Args) } @@ -110,17 +115,14 @@ func (c *Checker) HandleFunction(pkgName, methodName string) *Violation { } func (c *Checker) HandleMethod(receiver ast.Expr, method string) *Violation { - if c.types == nil || !c.types.Types[receiver].IsValue() { + if c.types == nil { return nil } - tv := c.types.Types[receiver] - if tv.Type == nil { - // todo(butuzov): logError + tv := c.types.Types[receiver] + if !tv.IsValue() || tv.Type == nil { return nil - } - - if methods, ok := c.Methods[cleanName(tv.Type.String())]; !ok { + } else if methods, ok := c.Methods[cleanName(tv.Type.String())]; !ok { return nil } else if violation, ok := methods[method]; ok { return &violation @@ -144,6 +146,17 @@ func (c *Checker) isImported(pkg, name string) bool { return false } +// todo: not implemented +func (c *Checker) isConverterCall(ce *ast.CallExpr) bool { + switch ce.Fun.(type) { + case *ast.ArrayType: + return c.types.TypeOf(ce.Fun).String() == "[]byte" + case *ast.Ident: + return c.types.TypeOf(ce.Fun).String() == "string" + } + return false +} + // cleanName will remove * from the name variable if it is a pointer. func cleanName(name string) string { if name[0] == '*' { diff --git a/internal/checker/types.go b/internal/checker/types.go deleted file mode 100644 index ad66581..0000000 --- a/internal/checker/types.go +++ /dev/null @@ -1,30 +0,0 @@ -package checker - -type Type string - -const ( - String Type = "string" - Bytes Type = "[]byte" -) - -type ReturnType int - -const ( - typeUnknown ReturnType = iota - typeByte - typeString - typeByteSlice -) - -func (t ReturnType) String() string { - switch t { - case typeByte: - return `byte` - case typeString: - return `string` - case typeByteSlice: - return `[]byte` - } - - return "unknown" -} diff --git a/internal/checker/violation.go b/internal/checker/violation.go index 52ef7d7..4bd9de2 100644 --- a/internal/checker/violation.go +++ b/internal/checker/violation.go @@ -51,12 +51,12 @@ func (v *Violation) Handle(ce *ast.CallExpr) (m map[int]ast.Expr, ok bool) { return m, len(m) == len(v.Args) } -func (v *Violation) Targets() Type { +func (v *Violation) Targets() string { if !v.StringTargeted { - return Bytes + return "[]byte" } - return String + return "string" } // TODO: not implemented