From e95f04b005014c64a78ec2228e60b2fb0fec8f00 Mon Sep 17 00:00:00 2001 From: Fatih Arslan Date: Sat, 30 Nov 2019 11:45:39 -0800 Subject: [PATCH] Add source code with updated README and License --- LICENSE | 33 +++ README.md | 61 ++++- cmd/errwrap/main.go | 10 + go.mod | 5 + go.sum | 8 + internal/errwrap/errwrap.go | 487 ++++++++++++++++++++++++++++++++++++ 6 files changed, 603 insertions(+), 1 deletion(-) create mode 100644 cmd/errwrap/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/errwrap/errwrap.go diff --git a/LICENSE b/LICENSE index 85025b6..5730ddd 100644 --- a/LICENSE +++ b/LICENSE @@ -27,3 +27,36 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +This software includes some portions from Go. Go is used under the terms of the +BSD like license. + +Copyright (c) 2012 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +The Go gopher was designed by Renee French. http://reneefrench.blogspot.com/ The design is licensed under the Creative Commons 3.0 Attributions license. Read this article for more details: https://blog.golang.org/gopher diff --git a/README.md b/README.md index 47fb604..c28e96f 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,61 @@ # errwrap -Wrap and fix Go errors with the new %w verb directive + +Wrap and fix Go errors with the new %w verb directive. This tool analyzes +`fmt.Errorf()` calls and reports calls that contain a verb directive that is +different than the new `%w` verb directive [introduced in Go v1.13](https://golang.org/doc/go1.13#error_wrapping). It's also capable of rewriting calls to use the new `%w` wrap verb directive. + + +# Install + +```bash +go get github.com/fatih/errwrap/cmd/errwrap +``` + +# Usage + +By default, `errwrap` prints the output of the analyzer to stdout. You can pass +a file, directory or a Go package: + +```sh +$ errwrap foo.go # pass a file +$ errwrap ./... # recursively analyze all files +$ errwrap github.com/fatih/gomodifytags # or pass a package +``` + +When called it displays the error with the line and column: + +``` +gomodifytags@v1.0.1/main.go:200:16: call could wrap the error with error-wrapping directive %w +gomodifytags@v1.0.1/main.go:641:17: call could wrap the error with error-wrapping directive %w +gomodifytags@v1.0.1/main.go:749:15: call could wrap the error with error-wrapping directive %w +``` + +`errwrap` is also able to rewrite your source code to replace any verb +directive used for an `error` type with the `%w` verb directive. Assume we have +the following source code: + +``` +$ cat demo.go +package main + +import ( + "errors" + "fmt" +) + +func main() { + _ = foo() +} + +func foo() error { + err := errors.New("bar!") + return fmt.Errorf("foo failed: %s: %w bar ...", "foo", err) +} +``` + +Calling `errwrap` with the `-fix` flag will rewrite the source code: + +``` +$ errwrap -fix example.go +demo.go:14:9: call could wrap the error with error-wrapping directive %w +``` diff --git a/cmd/errwrap/main.go b/cmd/errwrap/main.go new file mode 100644 index 0000000..63295db --- /dev/null +++ b/cmd/errwrap/main.go @@ -0,0 +1,10 @@ +package main + +import ( + "github.com/fatih/errwrap/internal/errwrap" + "golang.org/x/tools/go/analysis/singlechecker" +) + +func main() { + singlechecker.Main(errwrap.Analyzer) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9226712 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/fatih/errwrap + +go 1.13 + +require golang.org/x/tools v0.0.0-20191014205221-18e3458ac98b diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e00f8d4 --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191014205221-18e3458ac98b h1:EsQHTYgcM562dq02r6y2Yt9VpvvLNIyNECx96XQeolA= +golang.org/x/tools v0.0.0-20191014205221-18e3458ac98b/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/errwrap/errwrap.go b/internal/errwrap/errwrap.go new file mode 100644 index 0000000..2372b5d --- /dev/null +++ b/internal/errwrap/errwrap.go @@ -0,0 +1,487 @@ +// Package errwrap defines an Analyzer that rewrites error statements to use the +// new wrapping/unwrapping functionality +package errwrap + +import ( + "bytes" + "fmt" + "go/ast" + "go/constant" + "go/printer" + "go/token" + "go/types" + "strconv" + "strings" + "unicode/utf8" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/go/types/typeutil" +) + +var Analyzer = &analysis.Analyzer{ + Name: "errwrap", + Doc: "wrap errors in fmt.Errorf() calls with the %w verb directive", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: run, + RunDespiteErrors: true, +} + +func run(pass *analysis.Pass) (interface{}, error) { + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + nodeFilter := []ast.Node{ + (*ast.CallExpr)(nil), + } + + inspect.Preorder(nodeFilter, func(n ast.Node) { + call := n.(*ast.CallExpr) + + fn, _ := typeutil.Callee(pass.TypesInfo, call).(*types.Func) + if fn == nil { + return + } + + // for now only check these functions + if fn.FullName() != "fmt.Errorf" { + return + } + + oldExpr := render(pass.Fset, call) + + format, idx := formatString(pass, call) + if idx < 0 { + // call has arguments but no formatting directives + return + } + + firstArg := idx + 1 // Arguments are immediately after format string. + if !strings.Contains(format, "%") { + if len(call.Args) > firstArg { + pass.Reportf(call.Lparen, "%s call has arguments but no formatting directives", fn.Name()) + } + return + } + + var hasError bool + var errIndex int + for i, arg := range call.Args { + if t := pass.TypesInfo.TypeOf(arg); t != nil { + if t.String() == "error" { + hasError = true + errIndex = i + } + } + } + + if !hasError { + return + } + + argNum := firstArg + maxArgNum := firstArg + anyIndex := false + anyW := false + newFormat := []byte(format) + for i, w := 0, 0; i < len(format); i += w { + w = 1 + if format[i] != '%' { + continue + } + + state := parsePrintfVerb(pass, call, fn.Name(), format[i:], firstArg, argNum) + if state == nil { + return + } + + w = len(state.format) + if state.hasIndex { + anyIndex = true + } + + if len(state.argNums) > 0 { + // Continue with the next sequential argument. + argNum = state.argNums[len(state.argNums)-1] + 1 + } + + for _, n := range state.argNums { + if n >= maxArgNum { + maxArgNum = n + 1 + } + } + + if state.argNum != errIndex { + continue + } + + if state.verb == 'w' { + if anyW { + pass.Reportf(call.Pos(), "%s call has more than one error-wrapping directive %%w", state.name) + return + } + anyW = true + continue + } + + newFormat[i+1] = 'w' + + if bl, ok := call.Args[0].(*ast.BasicLit); ok { + // replace the expression, keep the arguments the same + call.Args[0] = &ast.BasicLit{ + Value: strconv.Quote(string(newFormat)), + ValuePos: bl.ValuePos, + Kind: bl.Kind, + } + } + + newExpr := render(pass.Fset, call) + + pass.Report(analysis.Diagnostic{ + Pos: call.Pos(), + Message: "call could wrap the error with error-wrapping directive %w", + SuggestedFixes: []analysis.SuggestedFix{ + { + Message: fmt.Sprintf("should replace `%s` with `%s`", oldExpr, newExpr), + TextEdits: []analysis.TextEdit{ + { + Pos: call.Pos(), + End: call.End(), + NewText: []byte(newExpr), + }, + }, + }, + }, + }) + } + + // Dotdotdot is hard. + if call.Ellipsis.IsValid() && maxArgNum >= len(call.Args)-1 { + return + } + // If any formats are indexed, extra arguments are ignored. + if anyIndex { + return + } + // There should be no leftover arguments. + if maxArgNum != len(call.Args) { + expect := maxArgNum - firstArg + numArgs := len(call.Args) - firstArg + pass.Reportf(call.Pos(), "%s call needs %v but has %v", fn.Name(), count(expect, "arg"), count(numArgs, "arg")) + } + + // If any formats are indexed, extra arguments are ignored. + if anyIndex { + return + } + + return + }) + + return nil, nil +} + +// render returns the pretty-print of the given node +func render(fset *token.FileSet, x interface{}) string { + var buf bytes.Buffer + if err := printer.Fprint(&buf, fset, x); err != nil { + panic(err) + } + return buf.String() +} + +// importPath returns the unquoted import path of s, +// or "" if the path is not properly quoted. +func importPath(s *ast.ImportSpec) string { + t, err := strconv.Unquote(s.Path.Value) + if err == nil { + return t + } + return "" +} + +// matchesSel matches the given sel slice with the selectors passed. This +// should be mostly used in conjuction with the sel() function +func matchesSel(sel []string, sels ...string) bool { + if len(sel) != len(sels) { + return false + } + + for i, s := range sel { + if s != sels[i] { + return false + } + } + + return true +} + +// sel returns the selection expression names from a call expressions. +// i.e: fmt.Errof() returns a slice of ["fmt", "Errorf"]. A call expression of +// t.Foo().Bar() returns ["t", "Foo", "Bar"] or ["t", "Foo"] depending on which +// part of the call expression is passed. A nil expr or non *ast.CallExpr +// returns nil +func sel(expr ast.Expr) []string { + if expr == nil { + return nil + } + + ce, ok := expr.(*ast.CallExpr) + if !ok { + return nil + } + + se, ok := ce.Fun.(*ast.SelectorExpr) + if !ok { + return nil + } + + res := []string{} + + if ce, ok := se.X.(*ast.CallExpr); ok { + partial := sel(ce) + res = append(res, partial...) + res = append(res, se.Sel.Name) + return res + } + + if id, ok := se.X.(*ast.Ident); ok { + res = append(res, id.Name) + } + + res = append(res, se.Sel.Name) + return res + +} + +// +// NOTE(arslan): Copied from go/analysis/passes/printf/printf.go +// + +// formatState holds the parsed representation of a printf directive such as "%3.*[4]d". +// It is constructed by parsePrintfVerb. +type formatState struct { + verb rune // the format verb: 'd' for "%d" + format string // the full format directive from % through verb, "%.3d". + name string // Printf, Sprintf etc. + flags []byte // the list of # + etc. + argNums []int // the successive argument numbers that are consumed, adjusted to refer to actual arg in call + firstArg int // Index of first argument after the format in the Printf call. + pos int // index of the verb in the format string + + // Used only during parse. + pass *analysis.Pass + call *ast.CallExpr + argNum int // Which argument we're expecting to format now. + hasIndex bool // Whether the argument is indexed. + indexPending bool // Whether we have an indexed argument that has not resolved. + nbytes int // number of bytes of the format string consumed. +} + +// formatString returns the format string argument and its index within +// the given printf-like call expression. +// +// The last parameter before variadic arguments is assumed to be +// a format string. +// +// The first string literal or string constant is assumed to be a format string +// if the call's signature cannot be determined. +// +// If it cannot find any format string parameter, it returns ("", -1). +func formatString(pass *analysis.Pass, call *ast.CallExpr) (format string, idx int) { + typ := pass.TypesInfo.Types[call.Fun].Type + if typ != nil { + if sig, ok := typ.(*types.Signature); ok { + if !sig.Variadic() { + // Skip checking non-variadic functions. + return "", -1 + } + idx := sig.Params().Len() - 2 + if idx < 0 { + // Skip checking variadic functions without + // fixed arguments. + return "", -1 + } + s, ok := stringConstantArg(pass, call, idx) + if !ok { + // The last argument before variadic args isn't a string. + return "", -1 + } + return s, idx + } + } + + // Cannot determine call's signature. Fall back to scanning for the first + // string constant in the call. + for idx := range call.Args { + if s, ok := stringConstantArg(pass, call, idx); ok { + return s, idx + } + if pass.TypesInfo.Types[call.Args[idx]].Type == types.Typ[types.String] { + // Skip checking a call with a non-constant format + // string argument, since its contents are unavailable + // for validation. + return "", -1 + } + } + return "", -1 +} + +// stringConstantArg returns call's string constant argument at the index idx. +// +// ("", false) is returned if call's argument at the index idx isn't a string +// constant. +func stringConstantArg(pass *analysis.Pass, call *ast.CallExpr, idx int) (string, bool) { + if idx >= len(call.Args) { + return "", false + } + arg := call.Args[idx] + lit := pass.TypesInfo.Types[arg].Value + if lit != nil && lit.Kind() == constant.String { + return constant.StringVal(lit), true + } + return "", false +} + +// parsePrintfVerb looks the formatting directive that begins the format string +// and returns a formatState that encodes what the directive wants, without looking +// at the actual arguments present in the call. The result is nil if there is an error. +func parsePrintfVerb(pass *analysis.Pass, call *ast.CallExpr, name, format string, firstArg, argNum int) *formatState { + state := &formatState{ + format: format, + name: name, + flags: make([]byte, 0, 5), + argNum: argNum, + argNums: make([]int, 0, 1), + nbytes: 1, // There's guaranteed to be a percent sign. + firstArg: firstArg, + pass: pass, + call: call, + } + // There may be flags. + state.parseFlags() + // There may be an index. + if !state.parseIndex() { + return nil + } + // There may be a width. + if !state.parseNum() { + return nil + } + // There may be a precision. + if !state.parsePrecision() { + return nil + } + // Now a verb, possibly prefixed by an index (which we may already have). + if !state.indexPending && !state.parseIndex() { + return nil + } + + if state.nbytes == len(state.format) { + pass.Reportf(call.Pos(), "%s format %s is missing verb at end of string", name, state.format) + return nil + } + verb, w := utf8.DecodeRuneInString(state.format[state.nbytes:]) + state.verb = verb + state.nbytes += w + if verb != '%' { + state.argNums = append(state.argNums, state.argNum) + } + state.format = state.format[:state.nbytes] + return state +} + +// parseFlags accepts any printf flags. +func (s *formatState) parseFlags() { + for s.nbytes < len(s.format) { + switch c := s.format[s.nbytes]; c { + case '#', '0', '+', '-', ' ': + s.flags = append(s.flags, c) + s.nbytes++ + default: + return + } + } +} + +// scanNum advances through a decimal number if present. +func (s *formatState) scanNum() { + for ; s.nbytes < len(s.format); s.nbytes++ { + c := s.format[s.nbytes] + if c < '0' || '9' < c { + return + } + } +} + +// parseIndex scans an index expression. It returns false if there is a syntax error. +func (s *formatState) parseIndex() bool { + if s.nbytes == len(s.format) || s.format[s.nbytes] != '[' { + return true + } + // Argument index present. + s.nbytes++ // skip '[' + start := s.nbytes + s.scanNum() + ok := true + if s.nbytes == len(s.format) || s.nbytes == start || s.format[s.nbytes] != ']' { + ok = false + s.nbytes = strings.Index(s.format, "]") + if s.nbytes < 0 { + s.pass.Reportf(s.call.Pos(), "%s format %s is missing closing ]", s.name, s.format) + return false + } + } + arg32, err := strconv.ParseInt(s.format[start:s.nbytes], 10, 32) + if err != nil || !ok || arg32 <= 0 || arg32 > int64(len(s.call.Args)-s.firstArg) { + s.pass.Reportf(s.call.Pos(), "%s format has invalid argument index [%s]", s.name, s.format[start:s.nbytes]) + return false + } + s.nbytes++ // skip ']' + arg := int(arg32) + arg += s.firstArg - 1 // We want to zero-index the actual arguments. + s.argNum = arg + s.hasIndex = true + s.indexPending = true + return true +} + +// parseNum scans a width or precision (or *). It returns false if there's a bad index expression. +func (s *formatState) parseNum() bool { + if s.nbytes < len(s.format) && s.format[s.nbytes] == '*' { + if s.indexPending { // Absorb it. + s.indexPending = false + } + s.nbytes++ + s.argNums = append(s.argNums, s.argNum) + s.argNum++ + } else { + s.scanNum() + } + return true +} + +// parsePrecision scans for a precision. It returns false if there's a bad index expression. +func (s *formatState) parsePrecision() bool { + // If there's a period, there may be a precision. + if s.nbytes < len(s.format) && s.format[s.nbytes] == '.' { + s.flags = append(s.flags, '.') // Treat precision as a flag. + s.nbytes++ + if !s.parseIndex() { + return false + } + if !s.parseNum() { + return false + } + } + return true +} + +// count(n, what) returns "1 what" or "N whats" +// (assuming the plural of what is whats). +func count(n int, what string) string { + if n == 1 { + return "1 " + what + } + return fmt.Sprintf("%d %ss", n, what) +}