diff --git a/cmd/linters/main.go b/cmd/linters/main.go index 37819428604..f3d4668aacb 100644 --- a/cmd/linters/main.go +++ b/cmd/linters/main.go @@ -17,6 +17,7 @@ import ( "golang.org/x/tools/go/analysis/multichecker" "github.com/github/gh-aw/pkg/linters/ctxbackground" + "github.com/github/gh-aw/pkg/linters/errstringmatch" "github.com/github/gh-aw/pkg/linters/excessivefuncparams" "github.com/github/gh-aw/pkg/linters/largefunc" "github.com/github/gh-aw/pkg/linters/osexitinlibrary" @@ -27,6 +28,7 @@ import ( func main() { multichecker.Main( ctxbackground.Analyzer, + errstringmatch.Analyzer, excessivefuncparams.Analyzer, largefunc.Analyzer, osexitinlibrary.Analyzer, diff --git a/pkg/linters/errstringmatch/errstringmatch.go b/pkg/linters/errstringmatch/errstringmatch.go new file mode 100644 index 00000000000..193282890f1 --- /dev/null +++ b/pkg/linters/errstringmatch/errstringmatch.go @@ -0,0 +1,135 @@ +// Package errstringmatch implements a Go analysis linter that flags +// calls to strings.Contains(err.Error(), "literal") that perform brittle +// substring matching on error messages instead of using errors.Is or errors.As. +package errstringmatch + +import ( + "go/ast" + "go/token" + "go/types" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +// Analyzer is the err-string-match analysis pass. +var Analyzer = &analysis.Analyzer{ + Name: "errstringmatch", + Doc: "reports strings.Contains(err.Error(), \"...\") calls that perform brittle substring matching on error messages", + URL: "https://github.com/github/gh-aw/tree/main/pkg/linters/errstringmatch", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: run, +} + +func run(pass *analysis.Pass) (any, error) { + insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + nodeFilter := []ast.Node{ + (*ast.CallExpr)(nil), + } + + insp.Preorder(nodeFilter, func(n ast.Node) { + outer, ok := n.(*ast.CallExpr) + if !ok { + return + } + + // Match strings.Contains(X, Y) + if !isStringsContains(outer) { + return + } + if len(outer.Args) != 2 { + return + } + + // First arg must be a call to err.Error() + if !isErrDotError(pass, outer.Args[0]) { + return + } + + // Second arg must be a string literal (or at least a string type) + if !isStringLiteral(pass, outer.Args[1]) { + return + } + + pass.Reportf(outer.Pos(), "avoid strings.Contains(err.Error(), ...) — use errors.Is, errors.As, or a sentinel error instead") + }) + + return nil, nil +} + +// isStringsContains returns true for strings.Contains(...) call expressions. +func isStringsContains(call *ast.CallExpr) bool { + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return false + } + ident, ok := sel.X.(*ast.Ident) + if !ok { + return false + } + return ident.Name == "strings" && sel.Sel.Name == "Contains" +} + +// isErrDotError returns true when expr is a method call of the form .Error() +// where the receiver implements the error interface. +func isErrDotError(pass *analysis.Pass, expr ast.Expr) bool { + call, ok := expr.(*ast.CallExpr) + if !ok { + return false + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return false + } + if sel.Sel.Name != "Error" { + return false + } + if len(call.Args) != 0 { + return false + } + // Check that the receiver implements the error interface. + t := pass.TypesInfo.TypeOf(sel.X) + if t == nil { + return false + } + return implementsError(pass, t) +} + +// implementsError reports whether t implements the built-in error interface. +func implementsError(pass *analysis.Pass, t types.Type) bool { + errIface := pass.Pkg.Scope().Lookup("error") + if errIface == nil { + // Look up the universe scope. + obj := types.Universe.Lookup("error") + if obj == nil { + return false + } + iface, ok := obj.Type().Underlying().(*types.Interface) + if !ok { + return false + } + return types.Implements(t, iface) || types.Implements(types.NewPointer(t), iface) + } + iface, ok := errIface.Type().Underlying().(*types.Interface) + if !ok { + return false + } + return types.Implements(t, iface) || types.Implements(types.NewPointer(t), iface) +} + +// isStringLiteral returns true when expr is a string literal or untyped string constant. +func isStringLiteral(pass *analysis.Pass, expr ast.Expr) bool { + lit, ok := expr.(*ast.BasicLit) + if ok && lit.Kind == token.STRING { + return true + } + // Also accept typed/untyped string constants (e.g. a const identifier). + t := pass.TypesInfo.TypeOf(expr) + if t == nil { + return false + } + basic, ok := t.Underlying().(*types.Basic) + return ok && basic.Kind() == types.String +} diff --git a/pkg/linters/errstringmatch/errstringmatch_test.go b/pkg/linters/errstringmatch/errstringmatch_test.go new file mode 100644 index 00000000000..1aa0c1d1b92 --- /dev/null +++ b/pkg/linters/errstringmatch/errstringmatch_test.go @@ -0,0 +1,16 @@ +//go:build !integration + +package errstringmatch_test + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" + + "github.com/github/gh-aw/pkg/linters/errstringmatch" +) + +func TestErrStringMatch(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, errstringmatch.Analyzer, "errstringmatch") +} diff --git a/pkg/linters/errstringmatch/testdata/src/errstringmatch/errstringmatch.go b/pkg/linters/errstringmatch/testdata/src/errstringmatch/errstringmatch.go new file mode 100644 index 00000000000..5248ce95264 --- /dev/null +++ b/pkg/linters/errstringmatch/testdata/src/errstringmatch/errstringmatch.go @@ -0,0 +1,28 @@ +package errstringmatch + +import ( + "errors" + "strings" +) + +var errNotFound = errors.New("not found") + +// flagged: strings.Contains on err.Error() with a string literal +func checkError(err error) bool { + return strings.Contains(err.Error(), "not found") // want `avoid strings\.Contains\(err\.Error\(\)` +} + +// flagged: same pattern with a different variable name +func checkPermission(e error) bool { + return strings.Contains(e.Error(), "403") // want `avoid strings\.Contains\(err\.Error\(\)` +} + +// not flagged: using errors.Is +func checkErrorSafe(err error) bool { + return errors.Is(err, errNotFound) +} + +// not flagged: strings.Contains on a plain string, not err.Error() +func checkString(s string) bool { + return strings.Contains(s, "prefix") +}