From e7f8bd05b71674bd78f7fcdf1c75aa17791d20e6 Mon Sep 17 00:00:00 2001 From: sylvia7788 <1227977886@qq.com> Date: Wed, 13 Jul 2022 14:21:51 +0800 Subject: [PATCH] use fact --- contextcheck.go | 148 ++++++++++++++++++++++++------------------- contextcheck_test.go | 5 +- dep.go | 116 --------------------------------- testdata/src/a/a.go | 6 +- 4 files changed, 90 insertions(+), 185 deletions(-) delete mode 100644 dep.go diff --git a/contextcheck.go b/contextcheck.go index 0434eb8..6747d46 100644 --- a/contextcheck.go +++ b/contextcheck.go @@ -6,7 +6,6 @@ import ( "go/types" "strconv" "strings" - "sync" "github.com/gostaticanalysis/analysisutil" "golang.org/x/tools/go/analysis" @@ -23,6 +22,7 @@ func NewAnalyzer() *analysis.Analyzer { Requires: []*analysis.Analyzer{ buildssa.Analyzer, }, + FactTypes: []analysis.Fact{(*ctxFact)(nil)}, } } @@ -39,11 +39,15 @@ const ( CtxInOut = CtxIn | CtxOut ) -var ( - checkedMap = make(map[string]bool) - checkedMapLock sync.RWMutex - c *collector -) +type resInfo struct { + Valid bool + Funcs []string +} + +type ctxFact map[string]resInfo + +func (*ctxFact) String() string { return "ctxCheck" } +func (*ctxFact) AFact() {} type runner struct { pass *analysis.Pass @@ -51,15 +55,22 @@ type runner struct { ctxPTyp *types.Pointer cmpPath string skipFile map[*ast.File]bool + + currentFact ctxFact } func NewRun(pkgs []*packages.Package) func(pass *analysis.Pass) (interface{}, error) { - c = newCollector(pkgs) + m := make(map[string]bool) + for _, pkg := range pkgs { + m[strings.Split(pkg.PkgPath, "/")[0]] = true + } return func(pass *analysis.Pass) (interface{}, error) { - defer c.DecUse(pass) + // skip different repo + if !m[strings.Split(pass.Pkg.Path(), "/")[0]] { + return nil, nil + } - r := new(runner) - r.run(pass) + new(runner).run(pass) return nil, nil } } @@ -90,21 +101,33 @@ func (r *runner) run(pass *analysis.Pass) { } r.skipFile = make(map[*ast.File]bool) + r.currentFact = make(ctxFact) + var tmpFuncs []*ssa.Function for _, f := range funcs { // skip checked function key := f.RelString(nil) - _, ok := getValue(key) - if ok { + if _, ok := r.currentFact[key]; ok { continue } if !r.checkIsEntry(f, f.Pos()) { + // record the result of nomal function + checkingMap := make(map[string]bool) + checkingMap[key] = true + r.setFact(key, r.checkFuncWithoutCtx(f, checkingMap), f.Name()) continue } + tmpFuncs = append(tmpFuncs, f) + } + + for _, f := range tmpFuncs { r.checkFuncWithCtx(f) - setValue(key, true) + } + + if len(r.currentFact) > 0 { + pass.ExportPackageFact(&r.currentFact) } } @@ -269,16 +292,6 @@ func (r *runner) collectCtxRef(f *ssa.Function) (refMap map[ssa.Instruction]bool return } -func (r *runner) buildPkg(f *ssa.Function) (ff *ssa.Function) { - if f.Blocks != nil { - ff = f - return - } - - ff = c.GetFunction(f) - return -} - func (r *runner) checkIsSameRepo(s string) bool { return strings.HasPrefix(s, r.cmpPath+"/") } @@ -313,31 +326,10 @@ func (r *runner) checkFuncWithCtx(f *ssa.Function) { } key := ff.RelString(nil) - valid, ok := getValue(key) + res, ok := r.getValue(key, ff) if ok { - if !valid { - r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name()) - } - continue - } - - // check is thunk or bound - if strings.HasSuffix(key, "$thunk") || strings.HasSuffix(key, "$bound") { - continue - } - - // if ff has no ctx, start deep traversal check - if !r.checkIsEntry(ff, instr.Pos()) { - if ff = r.buildPkg(ff); ff == nil { - continue - } - - checkingMap := make(map[string]bool) - checkingMap[key] = true - valid := r.checkFuncWithoutCtx(ff, checkingMap) - setValue(key, valid) - if !valid { - r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name()) + if !res.Valid { + r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", strings.Join(reverse(res.Funcs), "->")) } } } @@ -346,6 +338,7 @@ func (r *runner) checkFuncWithCtx(f *ssa.Function) { func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]bool) (ret bool) { ret = true + orgKey := f.RelString(nil) for _, b := range f.Blocks { for _, instr := range b.Instrs { tp, ok := r.getCtxType(instr) @@ -362,7 +355,6 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo if tp&CtxInField == 0 { ret = false } - continue } ff := r.getFunction(instr) @@ -371,11 +363,13 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo } key := ff.RelString(nil) - valid, ok := getValue(key) + res, ok := r.getValue(key, ff) if ok { - if !valid { + if !res.Valid { ret = false - r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name()) + + // save the call link + r.setFact(orgKey, res.Valid, res.Funcs...) } continue } @@ -386,21 +380,21 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo } if !r.checkIsEntry(ff, instr.Pos()) { - // handler ring call - if checkingMap[key] { + // cannot get info from fact, skip + if ff.Blocks == nil { continue } - checkingMap[key] = true - if ff = r.buildPkg(ff); ff == nil { + // handler ring call + if checkingMap[key] { continue } + checkingMap[key] = true valid := r.checkFuncWithoutCtx(ff, checkingMap) - setValue(key, valid) + r.setFact(orgKey, valid, ff.Name()) if !valid { ret = false - r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name()) } } } @@ -501,15 +495,39 @@ func (r *runner) isCtxType(tp types.Type) bool { return types.Identical(tp, r.ctxTyp) || types.Identical(tp, r.ctxPTyp) } -func getValue(key string) (valid, ok bool) { - checkedMapLock.RLock() - valid, ok = checkedMap[key] - checkedMapLock.RUnlock() +func (r *runner) getValue(key string, f *ssa.Function) (res resInfo, ok bool) { + res, ok = r.currentFact[key] + if ok { + return + } + + if f.Pkg == nil { + return + } + + var fact ctxFact + if r.pass.ImportPackageFact(f.Pkg.Pkg, &fact) { + res, ok = fact[key] + } return } -func setValue(key string, valid bool) { - checkedMapLock.Lock() - checkedMap[key] = valid - checkedMapLock.Unlock() +func (r *runner) setFact(key string, valid bool, funcs ...string) { + r.currentFact[key] = resInfo{ + Valid: valid, + Funcs: append(r.currentFact[key].Funcs, funcs...), + } +} + +func reverse(arr1 []string) (arr2 []string) { + l := len(arr1) + if l == 0 { + return + } + arr2 = make([]string, l) + for i := 0; i <= l/2; i++ { + arr2[i] = arr1[l-1-i] + arr2[l-1-i] = arr1[i] + } + return } diff --git a/contextcheck_test.go b/contextcheck_test.go index 7d9421f..ce7f8fa 100644 --- a/contextcheck_test.go +++ b/contextcheck_test.go @@ -6,10 +6,13 @@ import ( "github.com/sylvia7788/contextcheck" "golang.org/x/tools/go/analysis/analysistest" + "golang.org/x/tools/go/packages" ) func Test(t *testing.T) { log.SetFlags(log.Lshortfile) testdata := analysistest.TestData() - analysistest.Run(t, testdata, contextcheck.NewAnalyzer(), "a") + analyzer := contextcheck.NewAnalyzer() + analyzer.Run = contextcheck.NewRun([]*packages.Package{{PkgPath: "a"}}) + analysistest.Run(t, testdata, analyzer, "a") } diff --git a/dep.go b/dep.go deleted file mode 100644 index ecef1cf..0000000 --- a/dep.go +++ /dev/null @@ -1,116 +0,0 @@ -package contextcheck - -import ( - "go/types" - "sync/atomic" - - "golang.org/x/tools/go/analysis" - "golang.org/x/tools/go/analysis/passes/buildssa" - "golang.org/x/tools/go/packages" - "golang.org/x/tools/go/ssa" -) - -type pkgInfo struct { - pkgPkg *packages.Package // to find references later - ssaPkg *ssa.Package // to find func which has been built - refCnt int32 // reference count -} - -type collector struct { - m map[string]*pkgInfo -} - -func newCollector(pkgs []*packages.Package) (c *collector) { - c = &collector{ - m: make(map[string]*pkgInfo), - } - - // self-reference - for _, pkg := range pkgs { - c.m[pkg.PkgPath] = &pkgInfo{ - pkgPkg: pkg, - refCnt: 1, - } - } - - // import reference - for _, pkg := range pkgs { - for _, imp := range pkg.Imports { - if val, ok := c.m[imp.PkgPath]; ok { - val.refCnt++ - } - } - } - - return -} - -func (c *collector) DecUse(pass *analysis.Pass) { - curPkg, ok := c.m[pass.Pkg.Path()] - if !ok { - return - } - - if atomic.AddInt32(&curPkg.refCnt, -1) != 0 { - curPkg.ssaPkg = pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA).Pkg - return - } - - var release func(info *pkgInfo) - release = func(info *pkgInfo) { - for _, pkg := range info.pkgPkg.Imports { - if val, ok := c.m[pkg.PkgPath]; ok { - if atomic.AddInt32(&val.refCnt, -1) == 0 { - release(val) - } - } - } - - info.pkgPkg = nil - info.ssaPkg = nil - } - release(curPkg) -} - -func (c *collector) GetFunction(f *ssa.Function) (ff *ssa.Function) { - info, ok := c.m[f.Pkg.Pkg.Path()] - if !ok { - return - } - - // without recv => get by Func - recv := f.Signature.Recv() - if recv == nil { - ff = info.ssaPkg.Func(f.Name()) - return - } - - // with recv => find in prog according to type - ntp, ptp := getNamedType(recv.Type()) - if ntp == nil { - return - } - sel := info.ssaPkg.Prog.MethodSets.MethodSet(ntp).Lookup(ntp.Obj().Pkg(), f.Name()) - if sel == nil { - sel = info.ssaPkg.Prog.MethodSets.MethodSet(ptp).Lookup(ntp.Obj().Pkg(), f.Name()) - } - if sel == nil { - return - } - ff = info.ssaPkg.Prog.MethodValue(sel) - return -} - -func getNamedType(tp types.Type) (ntp *types.Named, ptp *types.Pointer) { - switch t := tp.(type) { - case *types.Named: - ntp = t - ptp = types.NewPointer(tp) - case *types.Pointer: - if n, ok := t.Elem().(*types.Named); ok { - ntp = n - ptp = t - } - } - return -} diff --git a/testdata/src/a/a.go b/testdata/src/a/a.go index 3ae0dba..5db336e 100644 --- a/testdata/src/a/a.go +++ b/testdata/src/a/a.go @@ -1,4 +1,4 @@ -package a +package a // want package:"ctxCheck" import "context" @@ -35,7 +35,7 @@ func f1(ctx context.Context) { newXX().Test() f3() // want "Function `f3` should pass the context parameter" - f6() // want "Function `f6` should pass the context parameter" + f6() // want "Function `f6->f3` should pass the context parameter" defer func() { f2(ctx) @@ -76,7 +76,7 @@ func f5(ctx context.Context) { } func f6() { - f3() // want "Function `f3` should pass the context parameter" + f3() } func f7(ctx context.Context) {