Skip to content

Commit

Permalink
New Rule: cap assertion
Browse files Browse the repository at this point in the history
Trigger warnings for directly assert the builtin `cap` function.

For example:
```go
Expect(cap(slice)).To(Equal(10)) ==> Expect(slice).To(HaveCap(10))
Expect(cap(slice) == 10).To(BeTrue()) ==> Expect(slice).To(HaveCap(10))
```
  • Loading branch information
nunnatsa committed Mar 14, 2024
1 parent 6d23c4a commit 4ade2b4
Show file tree
Hide file tree
Showing 11 changed files with 357 additions and 21 deletions.
18 changes: 16 additions & 2 deletions README.md
Expand Up @@ -259,6 +259,20 @@ The output of the linter,when finding issues, looks like this:
./testdata/src/a/a.go:18:5: ginkgo-linter: wrong length assertion; consider using `Expect("").Should(BeEmpty())` instead
./testdata/src/a/a.go:22:5: ginkgo-linter: wrong length assertion; consider using `Expect("").Should(BeEmpty())` instead
```

### Wrong Cap Assertion [STYLE]
The linter finds assertion of the golang built-in `cap` function, with all kind of matchers, while there are already
gomega matchers for these usecases; We want to assert the item, rather than its cap.

There are several wrong patterns:
```go
Expect(cap(x)).To(Equal(0)) // should be: Expect(x).To(HaveCap(0))
Expect(cap(x)).To(BeZero()) // should be: Expect(x).To(HaveCap(0))
Expect(cap(x)).To(BeNumeric(">", 0)) // should be: Expect(x).ToNot(HaveCap(0))
Expect(cap(x)).To(BeNumeric("==", 2)) // should be: Expect(x).To(HaveCap(2))
Expect(cap(x)).To(BeNumeric("!=", 3)) // should be: Expect(x).ToNot(HaveCap(3))
```

#### use the `HaveLen(0)` matcher. [STYLE]
The linter will also warn about the `HaveLen(0)` matcher, and will suggest to replace it with `BeEmpty()`

Expand Down Expand Up @@ -369,7 +383,7 @@ This rule support auto fixing.

## Suppress the linter
### Suppress warning from command line
* Use the `--suppress-len-assertion=true` flag to suppress the wrong length assertion warning
* Use the `--suppress-len-assertion=true` flag to suppress the wrong length and cap assertions warning
* Use the `--suppress-nil-assertion=true` flag to suppress the wrong nil assertion warning
* Use the `--suppress-err-assertion=true` flag to suppress the wrong error assertion warning
* Use the `--suppress-compare-assertion=true` flag to suppress the wrong comparison assertion warning
Expand All @@ -380,7 +394,7 @@ This rule support auto fixing.
command line, and not from a comment.

### Suppress warning from the code
To suppress the wrong length assertion warning, add a comment with (only)
To suppress the wrong length and cap assertions warning, add a comment with (only)

`ginkgo-linter:ignore-len-assert-warning`.

Expand Down
4 changes: 4 additions & 0 deletions analyzer_test.go
Expand Up @@ -89,6 +89,10 @@ func TestAllUseCases(t *testing.T) {
testName: "issue 124: custom matcher form other packages",
testData: "a/issue-124",
},
{
testName: "cap",
testData: "a/cap",
},
} {
t.Run(tc.testName, func(tt *testing.T) {
analysistest.Run(tt, analysistest.TestData(), ginkgolinter.NewAnalyzer(), tc.testData)
Expand Down
6 changes: 6 additions & 0 deletions doc.go
Expand Up @@ -30,6 +30,12 @@ For example:
This should be replaced with:
Expect(x)).Should(HavelLen(1))
* wrong cap assertions. We want to assert the item rather than its cap. [Style]
For example:
Expect(cap(x)).Should(Equal(1))
This should be replaced with:
Expect(x)).Should(HavelCap(1))
* wrong nil assertions. We want to assert the item rather than a comparison result. [Style]
For example:
Expect(x == nil).Should(BeTrue())
Expand Down
163 changes: 149 additions & 14 deletions linter/ginkgo_linter.go
Expand Up @@ -28,6 +28,7 @@ import (
const (
linterName = "ginkgo-linter"
wrongLengthWarningTemplate = "wrong length assertion"
wrongCapWarningTemplate = "wrong cap assertion"
wrongNilWarningTemplate = "wrong nil assertion"
wrongBoolWarningTemplate = "wrong boolean assertion"
wrongErrWarningTemplate = "wrong error assertion"
Expand Down Expand Up @@ -58,6 +59,7 @@ const ( // gomega matchers
beZero = "BeZero"
equal = "Equal"
haveLen = "HaveLen"
haveCap = "HaveCap"
haveOccurred = "HaveOccurred"
haveValue = "HaveValue"
not = "Not"
Expand Down Expand Up @@ -253,6 +255,10 @@ func forceExpectTo(expr *ast.CallExpr, handler gomegahandler.Handler, reportBuil
func doCheckExpression(pass *analysis.Pass, config types.Config, assertionExp *ast.CallExpr, actualArg ast.Expr, expr *ast.CallExpr, handler gomegahandler.Handler, reportBuilder *reports.Builder) bool {
if !bool(config.SuppressLen) && isActualIsLenFunc(actualArg) {
return checkLengthMatcher(expr, pass, handler, reportBuilder)

} else if !bool(config.SuppressLen) && isActualIsCapFunc(actualArg) {
return checkCapMatcher(expr, handler, reportBuilder)

} else if nilable, compOp := getNilableFromComparison(actualArg); nilable != nil {
if isExprError(pass, nilable) {
if config.SuppressErr {
Expand All @@ -269,9 +275,16 @@ func doCheckExpression(pass *analysis.Pass, config types.Config, assertionExp *a
if !shouldContinue {
return false
}
if !bool(config.SuppressLen) && isActualIsLenFunc(first) {
if handleLenComparison(pass, expr, matcher, first, second, op, handler, reportBuilder) {
return false
if !config.SuppressLen {
if isActualIsLenFunc(first) {
if handleLenComparison(pass, expr, matcher, first, second, op, handler, reportBuilder) {
return false
}
}
if isActualIsCapFunc(first) {
if handleCapComparison(expr, matcher, first, second, op, handler, reportBuilder) {
return false
}
}
}
return bool(config.SuppressCompare) || checkComparison(expr, pass, matcher, handler, first, second, op, reportBuilder)
Expand Down Expand Up @@ -814,15 +827,50 @@ func handleLenComparison(pass *analysis.Pass, exp *ast.CallExpr, matcher *ast.Ca
return true
}

func handleCapComparison(exp *ast.CallExpr, matcher *ast.CallExpr, first ast.Expr, second ast.Expr, op token.Token, handler gomegahandler.Handler, reportBuilder *reports.Builder) bool {
switch op {
case token.EQL:
case token.NEQ:
reverseAssertionFuncLogic(exp)
default:
return false
}

eql := ast.NewIdent(haveCap)
matcher.Args = []ast.Expr{second}

handler.ReplaceFunction(matcher, eql)
firstLen, ok := first.(*ast.CallExpr) // assuming it's len()
if !ok {
return false // should never happen
}

val := firstLen.Args[0]
fun := handler.GetActualExpr(exp.Fun.(*ast.SelectorExpr))
fun.Args = []ast.Expr{val}

reportBuilder.AddIssue(true, wrongCapWarningTemplate)
return true
}

// Check if the "actual" argument is a call to the golang built-in len() function
func isActualIsLenFunc(actualArg ast.Expr) bool {
return checkActualFuncName(actualArg, "len")
}

// Check if the "actual" argument is a call to the golang built-in len() function
func isActualIsCapFunc(actualArg ast.Expr) bool {
return checkActualFuncName(actualArg, "cap")
}

func checkActualFuncName(actualArg ast.Expr, name string) bool {
lenArgExp, ok := actualArg.(*ast.CallExpr)
if !ok {
return false
}

lenFunc, ok := lenArgExp.Fun.(*ast.Ident)
return ok && lenFunc.Name == "len"
return ok && lenFunc.Name == name
}

// Check if matcher function is in one of the patterns we want to avoid
Expand All @@ -839,7 +887,7 @@ func checkLengthMatcher(exp *ast.CallExpr, pass *analysis.Pass, handler gomegaha

switch matcherFuncName {
case equal:
handleEqualMatcher(matcher, pass, exp, handler, reportBuilder)
handleEqualLenMatcher(matcher, pass, exp, handler, reportBuilder)
return false

case beZero:
Expand All @@ -859,6 +907,40 @@ func checkLengthMatcher(exp *ast.CallExpr, pass *analysis.Pass, handler gomegaha
}
}

// Check if matcher function is in one of the patterns we want to avoid
func checkCapMatcher(exp *ast.CallExpr, handler gomegahandler.Handler, reportBuilder *reports.Builder) bool {
matcher, ok := exp.Args[0].(*ast.CallExpr)
if !ok {
return true
}

matcherFuncName, ok := handler.GetActualFuncName(matcher)
if !ok {
return true
}

switch matcherFuncName {
case equal:
handleEqualCapMatcher(matcher, exp, handler, reportBuilder)
return false

case beZero:
handleCapBeZero(exp, handler, reportBuilder)
return false

case beNumerically:
return handleCapBeNumerically(matcher, exp, handler, reportBuilder)

case not:
reverseAssertionFuncLogic(exp)
exp.Args[0] = exp.Args[0].(*ast.CallExpr).Args[0]
return checkCapMatcher(exp, handler, reportBuilder)

default:
return true
}
}

// Check if matcher function is in one of the patterns we want to avoid
func checkNilMatcher(exp *ast.CallExpr, pass *analysis.Pass, nilable ast.Expr, handler gomegahandler.Handler, notEqual bool, reportBuilder *reports.Builder) bool {
matcher, ok := exp.Args[0].(*ast.CallExpr)
Expand Down Expand Up @@ -1093,13 +1175,13 @@ func replaceLenActualArg(actualExpr *ast.CallExpr, handler gomegahandler.Handler
switch name {
case expect, omega:
arg := actualExpr.Args[0]
if isActualIsLenFunc(arg) {
if isActualIsLenFunc(arg) || isActualIsCapFunc(arg) {
// replace the len function call by its parameter, to create a fix suggestion
actualExpr.Args[0] = arg.(*ast.CallExpr).Args[0]
}
case expectWithOffset:
arg := actualExpr.Args[1]
if isActualIsLenFunc(arg) {
if isActualIsLenFunc(arg) || isActualIsCapFunc(arg) {
// replace the len function call by its parameter, to create a fix suggestion
actualExpr.Args[1] = arg.(*ast.CallExpr).Args[0]
}
Expand Down Expand Up @@ -1140,18 +1222,45 @@ func handleBeNumerically(matcher *ast.CallExpr, pass *analysis.Pass, exp *ast.Ca
reverseAssertionFuncLogic(exp)
handler.ReplaceFunction(exp.Args[0].(*ast.CallExpr), ast.NewIdent(beEmpty))
exp.Args[0].(*ast.CallExpr).Args = nil
reportLengthAssertion(exp, handler, reportBuilder)
return false
} else if op == `"=="` {
chooseNumericMatcher(pass, exp, handler, valExp)
reportLengthAssertion(exp, handler, reportBuilder)
return false
} else if op == `"!="` {
reverseAssertionFuncLogic(exp)
chooseNumericMatcher(pass, exp, handler, valExp)
reportLengthAssertion(exp, handler, reportBuilder)
return false
} else {
return true
}

reportLengthAssertion(exp, handler, reportBuilder)
return false
}
return true
}

// For the BeNumerically matcher, we want to avoid the assertion of length to be > 0 or >= 1, or just == number
func handleCapBeNumerically(matcher *ast.CallExpr, exp *ast.CallExpr, handler gomegahandler.Handler, reportBuilder *reports.Builder) bool {
opExp, ok1 := matcher.Args[0].(*ast.BasicLit)
valExp, ok2 := matcher.Args[1].(*ast.BasicLit)

if ok1 && ok2 {
op := opExp.Value
val := valExp.Value

if (op == `">"` && val == "0") || (op == `">="` && val == "1") {
reverseAssertionFuncLogic(exp)
handler.ReplaceFunction(exp.Args[0].(*ast.CallExpr), ast.NewIdent(haveCap))
exp.Args[0].(*ast.CallExpr).Args = []ast.Expr{&ast.BasicLit{Kind: token.INT, Value: "0"}}
} else if op == `"=="` {
replaceNumericCapMatcher(exp, handler, valExp)
} else if op == `"!="` {
reverseAssertionFuncLogic(exp)
replaceNumericCapMatcher(exp, handler, valExp)
} else {
return true
}

reportCapAssertion(exp, handler, reportBuilder)
return false
}
return true
}
Expand All @@ -1167,6 +1276,12 @@ func chooseNumericMatcher(pass *analysis.Pass, exp *ast.CallExpr, handler gomega
}
}

func replaceNumericCapMatcher(exp *ast.CallExpr, handler gomegahandler.Handler, valExp ast.Expr) {
caller := exp.Args[0].(*ast.CallExpr)
handler.ReplaceFunction(caller, ast.NewIdent(haveCap))
exp.Args[0].(*ast.CallExpr).Args = []ast.Expr{valExp}
}

func reverseAssertionFuncLogic(exp *ast.CallExpr) {
assertionFunc := exp.Fun.(*ast.SelectorExpr).Sel
assertionFunc.Name = reverseassertion.ChangeAssertionLogic(assertionFunc.Name)
Expand All @@ -1177,7 +1292,7 @@ func isNegativeAssertion(exp *ast.CallExpr) bool {
return reverseassertion.IsNegativeLogic(assertionFunc.Name)
}

func handleEqualMatcher(matcher *ast.CallExpr, pass *analysis.Pass, exp *ast.CallExpr, handler gomegahandler.Handler, reportBuilder *reports.Builder) {
func handleEqualLenMatcher(matcher *ast.CallExpr, pass *analysis.Pass, exp *ast.CallExpr, handler gomegahandler.Handler, reportBuilder *reports.Builder) {
equalTo, ok := matcher.Args[0].(*ast.BasicLit)
if ok {
chooseNumericMatcher(pass, exp, handler, equalTo)
Expand All @@ -1188,12 +1303,25 @@ func handleEqualMatcher(matcher *ast.CallExpr, pass *analysis.Pass, exp *ast.Cal
reportLengthAssertion(exp, handler, reportBuilder)
}

func handleEqualCapMatcher(matcher *ast.CallExpr, exp *ast.CallExpr, handler gomegahandler.Handler, reportBuilder *reports.Builder) {
handler.ReplaceFunction(exp.Args[0].(*ast.CallExpr), ast.NewIdent(haveCap))
exp.Args[0].(*ast.CallExpr).Args = []ast.Expr{matcher.Args[0]}
reportCapAssertion(exp, handler, reportBuilder)
}

func handleBeZero(exp *ast.CallExpr, handler gomegahandler.Handler, reportBuilder *reports.Builder) {
exp.Args[0].(*ast.CallExpr).Args = nil
handler.ReplaceFunction(exp.Args[0].(*ast.CallExpr), ast.NewIdent(beEmpty))
reportLengthAssertion(exp, handler, reportBuilder)
}

func handleCapBeZero(exp *ast.CallExpr, handler gomegahandler.Handler, reportBuilder *reports.Builder) {
exp.Args[0].(*ast.CallExpr).Args = nil
handler.ReplaceFunction(exp.Args[0].(*ast.CallExpr), ast.NewIdent(haveCap))
exp.Args[0].(*ast.CallExpr).Args = []ast.Expr{&ast.BasicLit{Kind: token.INT, Value: "0"}}
reportCapAssertion(exp, handler, reportBuilder)
}

func handleEqualNilMatcher(matcher *ast.CallExpr, pass *analysis.Pass, exp *ast.CallExpr, handler gomegahandler.Handler, nilable ast.Expr, notEqual bool, reportBuilder *reports.Builder) {
equalTo, ok := matcher.Args[0].(*ast.Ident)
if !ok {
Expand Down Expand Up @@ -1252,6 +1380,13 @@ func reportLengthAssertion(expr *ast.CallExpr, handler gomegahandler.Handler, re
reportBuilder.AddIssue(true, wrongLengthWarningTemplate)
}

func reportCapAssertion(expr *ast.CallExpr, handler gomegahandler.Handler, reportBuilder *reports.Builder) {
actualExpr := handler.GetActualExpr(expr.Fun.(*ast.SelectorExpr))
replaceLenActualArg(actualExpr, handler)

reportBuilder.AddIssue(true, wrongCapWarningTemplate)
}

func reportNilAssertion(expr *ast.CallExpr, handler gomegahandler.Handler, nilable ast.Expr, notEqual bool, isItError bool, reportBuilder *reports.Builder) {
actualExpr := handler.GetActualExpr(expr.Fun.(*ast.SelectorExpr))
changed := replaceNilActualArg(actualExpr, handler, nilable)
Expand Down
37 changes: 37 additions & 0 deletions testdata/src/a/cap/cap.ginkgo.go
@@ -0,0 +1,37 @@
package cap

import (
. "github.com/onsi/ginkgo/v2"
"github.com/onsi/gomega"
)

var _ = Describe("check cap", func() {
It("should not allow expect cap", func() {
slice := make([]int, 0, 10)
gomega.Expect(cap(slice)).To(gomega.Equal(10)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
gomega.Expect(cap(slice)).ToNot(gomega.Equal(0)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.ToNot\(gomega\.HaveCap\(0\)\). instead`
gomega.Expect(cap(slice)).ToNot(gomega.Equal(5)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.ToNot\(gomega\.HaveCap\(5\)\). instead`
gomega.Expect(cap(slice)).ToNot(gomega.BeZero()) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.ToNot\(gomega\.HaveCap\(0\)\). instead`
gomega.Expect(cap(slice)).To(gomega.BeNumerically("==", 10)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
gomega.Expect(cap(slice)).To(gomega.BeNumerically("!=", 0)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.ToNot\(gomega\.HaveCap\(0\)\). instead`
gomega.Expect(cap(slice)).ToNot(gomega.BeNumerically("==", 0)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.ToNot\(gomega\.HaveCap\(0\)\). instead`
gomega.Expect(cap(slice)).To(gomega.BeNumerically("!=", 5)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.ToNot\(gomega\.HaveCap\(5\)\). instead`
gomega.Expect(cap(slice)).ToNot(gomega.BeNumerically("==", 5)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.ToNot\(gomega\.HaveCap\(5\)\). instead`
gomega.Expect(cap(slice)).To(gomega.BeNumerically(">", 0)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.ToNot\(gomega\.HaveCap\(0\)\). instead`
gomega.Expect(cap(slice)).To(gomega.BeNumerically(">=", 1)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.ToNot\(gomega\.HaveCap\(0\)\). instead`
gomega.Expect(slice).To(gomega.BeEmpty())
gomega.Expect(slice).To(gomega.HaveCap(10))
})

It("should not allow comparison with cap", func() {
slice := make([]int, 0, 10)
gomega.Expect(cap(slice) == 10).To(gomega.BeTrue()) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
gomega.Expect(cap(slice) == 10).To(gomega.Equal(true)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
gomega.Expect(cap(slice) != 10).To(gomega.BeFalse()) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
gomega.Expect(cap(slice) != 10).To(gomega.Equal(false)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
gomega.Expect(cap(slice) == 10).ToNot(gomega.BeFalse()) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
gomega.Expect(cap(slice) == 10).ToNot(gomega.Equal(false)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
gomega.Expect(cap(slice) != 10).ToNot(gomega.BeTrue()) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
gomega.Expect(cap(slice) != 10).ToNot(gomega.Equal(true)) // want `ginkgo-linter: wrong cap assertion. Consider using .gomega\.Expect\(slice\)\.To\(gomega\.HaveCap\(10\)\). instead`
})
})

0 comments on commit 4ade2b4

Please sign in to comment.