Skip to content

Commit

Permalink
Merge pull request #15 from drystone/loop-vars
Browse files Browse the repository at this point in the history
  • Loading branch information
kunwardeep committed Jun 3, 2022
2 parents f435dce + 0fabdff commit 8ea8ed9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 70 deletions.
108 changes: 39 additions & 69 deletions pkg/paralleltest/paralleltest.go
Expand Up @@ -2,6 +2,7 @@ package paralleltest

import (
"go/ast"
"go/types"
"strings"

"golang.org/x/tools/go/analysis"
Expand Down Expand Up @@ -34,9 +35,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
funcDecl := node.(*ast.FuncDecl)
var funcHasParallelMethod,
rangeStatementOverTestCasesExists,
rangeStatementHasParallelMethod,
testLoopVariableReinitialised bool
var testRunLoopIdentifier string
rangeStatementHasParallelMethod bool
var loopVariableUsedInRun *string
var numberOfTestRun int
var positionOfTestRunNode []ast.Node
var rangeNode ast.Node
Expand Down Expand Up @@ -81,6 +81,13 @@ func run(pass *analysis.Pass) (interface{}, error) {
case *ast.RangeStmt:
rangeNode = v

var loopVars []types.Object
for _, expr := range []ast.Expr{v.Key, v.Value} {
if id, ok := expr.(*ast.Ident); ok {
loopVars = append(loopVars, pass.TypesInfo.ObjectOf(id))
}
}

ast.Inspect(v, func(n ast.Node) bool {
// nolint: gocritic
switch r := n.(type) {
Expand All @@ -90,26 +97,20 @@ func run(pass *analysis.Pass) (interface{}, error) {
innerTestVar := getRunCallbackParameterName(r.X)

rangeStatementOverTestCasesExists = true
testRunLoopIdentifier = methodRunFirstArgumentObjectName(r.X)

if !rangeStatementHasParallelMethod {
rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun(r.X, innerTestVar)
}

if loopVariableUsedInRun == nil {
if run, ok := r.X.(*ast.CallExpr); ok {
loopVariableUsedInRun = loopVarReferencedInRun(run, loopVars, pass.TypesInfo)
}
}
}
}
return true
})

// Check for the range loop value identifier re assignment
// More info here https://gist.github.com/kunwardeep/80c2e9f3d3256c894898bae82d9f75d0
if rangeStatementOverTestCasesExists {
var rangeValueIdentifier string
if i, ok := v.Value.(*ast.Ident); ok {
rangeValueIdentifier = i.Name
}

testLoopVariableReinitialised = testCaseLoopVariableReinitialised(v.Body.List, rangeValueIdentifier, testRunLoopIdentifier)
}
}
}

Expand All @@ -120,12 +121,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
if rangeStatementOverTestCasesExists && rangeNode != nil {
if !rangeStatementHasParallelMethod {
pass.Reportf(rangeNode.Pos(), "Range statement for test %s missing the call to method parallel in test Run\n", funcDecl.Name.Name)
} else {
if testRunLoopIdentifier == "" {
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not use range value in test Run\n", funcDecl.Name.Name)
} else if !testLoopVariableReinitialised {
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not reinitialise the variable %s\n", funcDecl.Name.Name, testRunLoopIdentifier)
}
} else if loopVariableUsedInRun != nil {
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not reinitialise the variable %s\n", funcDecl.Name.Name, *loopVariableUsedInRun)
}
}

Expand All @@ -140,38 +137,6 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil
}

func testCaseLoopVariableReinitialised(statements []ast.Stmt, rangeValueIdentifier string, testRunLoopIdentifier string) bool {
if len(statements) > 1 {
for _, s := range statements {
leftIdentifier, rightIdentifier := getLeftAndRightIdentifier(s)
if leftIdentifier == testRunLoopIdentifier && rightIdentifier == rangeValueIdentifier {
return true
}
}
}
return false
}

// Return the left hand side and the right hand side identifiers name
func getLeftAndRightIdentifier(s ast.Stmt) (string, string) {
var leftIdentifier, rightIdentifier string
// nolint: gocritic
switch v := s.(type) {
case *ast.AssignStmt:
if len(v.Rhs) == 1 {
if i, ok := v.Rhs[0].(*ast.Ident); ok {
rightIdentifier = i.Name
}
}
if len(v.Lhs) == 1 {
if i, ok := v.Lhs[0].(*ast.Ident); ok {
leftIdentifier = i.Name
}
}
}
return leftIdentifier, rightIdentifier
}

func methodParallelIsCalledInMethodRun(node ast.Node, testVar string) bool {
var methodParallelCalled bool
// nolint: gocritic
Expand Down Expand Up @@ -247,22 +212,6 @@ func getRunCallbackParameterName(node ast.Node) string {
return ""
}

// Gets the object name `tc` from method t.Run(tc.Foo, func(t *testing.T)
func methodRunFirstArgumentObjectName(node ast.Node) string {
// nolint: gocritic
switch n := node.(type) {
case *ast.CallExpr:
for _, arg := range n.Args {
if s, ok := arg.(*ast.SelectorExpr); ok {
if i, ok := s.X.(*ast.Ident); ok {
return i.Name
}
}
}
}
return ""
}

// Checks if the function has the param type *testing.T; if it does, then the
// parameter name is returned, too.
func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {
Expand Down Expand Up @@ -291,3 +240,24 @@ func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {

return false, ""
}

func loopVarReferencedInRun(call *ast.CallExpr, vars []types.Object, typeInfo *types.Info) (found *string) {
if len(call.Args) != 2 {
return
}

ast.Inspect(call.Args[1], func(n ast.Node) bool {
ident, ok := n.(*ast.Ident)
if !ok {
return true
}
for _, o := range vars {
if typeInfo.ObjectOf(ident) == o {
found = &ident.Name
}
}
return true
})

return
}
2 changes: 1 addition & 1 deletion pkg/paralleltest/testdata/src/t/t_test.go
Expand Up @@ -81,7 +81,7 @@ func TestFunctionRangeNotUsingRangeValueInTDotRun(t *testing.T) {
testCases := []struct {
name string
}{{name: "foo"}}
for _, tc := range testCases { // want "Range statement for test TestFunctionRangeNotUsingRangeValueInTDotRun does not use range value in test Run"
for _, tc := range testCases { // want "Range statement for test TestFunctionRangeNotUsingRangeValueInTDotRun does not reinitialise the variable tc"
t.Run("tc.name", func(t *testing.T) {
t.Parallel()
fmt.Println(tc.name)
Expand Down

0 comments on commit 8ea8ed9

Please sign in to comment.