From ca632157dbafe114fc5f8300b49809b7a577722a Mon Sep 17 00:00:00 2001 From: sylvia7788 <1227977886@qq.com> Date: Mon, 15 Aug 2022 17:55:10 +0800 Subject: [PATCH] feat: support check r.Context() --- contextcheck.go | 173 ++++++++++++++++---------------------------- testdata/src/a/a.go | 21 ++++-- 2 files changed, 78 insertions(+), 116 deletions(-) diff --git a/contextcheck.go b/contextcheck.go index 008914f..c0142b6 100644 --- a/contextcheck.go +++ b/contextcheck.go @@ -38,17 +38,14 @@ const ( CtxIn int = 1 << iota // ctx in function's param CtxOut // ctx in function's results CtxInField // ctx in function's field param - HttpRes // http.ResponseWriter in function's param - HttpReq // *http.Request in function's param - - HttpHandler = HttpRes | HttpReq ) -const ( - EntryWithCtx int = 1 << iota // has ctx in - EntryWithHttpHandler // is http handler +type entryType int - Entry = EntryWithCtx | EntryWithHttpHandler +const ( + EntryNone entryType = iota + EntryWithCtx // has ctx in + EntryWithHttpHandler // is http handler ) type resInfo struct { @@ -109,7 +106,7 @@ func (r *runner) run(pass *analysis.Pass) { type entryInfo struct { f *ssa.Function // entryfunc - tp int // entrytype + tp entryType // entrytype } var tmpFuncs []entryInfo for _, f := range funcs { @@ -119,7 +116,7 @@ func (r *runner) run(pass *analysis.Pass) { continue } - if entryType := r.checkIsEntry(f); entryType&Entry == 0 { + if entryType := r.checkIsEntry(f); entryType == EntryNone { // record the result of nomal function checkingMap := make(map[string]bool) checkingMap[key] = true @@ -161,7 +158,7 @@ func (r *runner) getRequiedType(pssa *buildssa.SSA, path, name string) (obj *typ func (r *runner) collectHttpTyps(pssa *buildssa.SSA) { objRes, pobjRes, ok := r.getRequiedType(pssa, httpPkg, httpRes) if ok { - r.httpResTyps = append(r.httpResTyps, objRes, pobjRes, types.NewPointer(pobjRes)) + r.httpResTyps = append(r.httpResTyps, objRes, pobjRes) } objReq, pobjReq, ok := r.getRequiedType(pssa, httpPkg, httpReq) @@ -201,27 +198,26 @@ func (r *runner) noImportedContextAndHttp(f *ssa.Function) (ret bool) { return true } -func (r *runner) checkIsEntry(f *ssa.Function) (entryType int) { +func (r *runner) checkIsEntry(f *ssa.Function) entryType { if r.noImportedContextAndHttp(f) { - return + return EntryNone } ctxIn, ctxOut := r.checkIsCtx(f) if ctxOut { // skip the function which generate ctx - return + return EntryNone } else if ctxIn { // has ctx in, ignore *http.Request.Context() - entryType |= EntryWithCtx - return + return EntryWithCtx } // check is `func handler(w http.ResponseWriter, r *http.Request) {}` if r.checkIsHttpHandler(f) { - entryType |= EntryWithHttpHandler + return EntryWithHttpHandler } - return + return EntryNone } func (r *runner) checkIsCtx(f *ssa.Function) (in, out bool) { @@ -259,39 +255,12 @@ func (r *runner) checkIsHttpHandler(f *ssa.Function) bool { return false } - // must has http.ResponseWriter and *http.Request in param or freevar - var tp int - - // check params + // must be `func f(w http.ResponseWriter, r *http.Request) {}` tuple := f.Signature.Params() - for i := 0; i < tuple.Len(); i++ { - if r.isCtxType(tuple.At(i).Type()) { - return false - } else if r.isHttpReqType(tuple.At(i).Type()) { - tp |= HttpReq - } else if r.isHttpResType(tuple.At(i).Type()) { - tp |= HttpRes - } - if tp == HttpHandler { - return true - } - } - - // check freevars - for _, param := range f.FreeVars { - if r.isCtxType(param.Type()) { - return false - } else if r.isHttpReqType(param.Type()) { - tp |= HttpReq - } else if r.isHttpResType(param.Type()) { - tp |= HttpRes - } - if tp == HttpHandler { - return true - } + if tuple.Len() != 2 { + return false } - - return false + return r.isHttpResType(tuple.At(0).Type()) && r.isHttpReqType(tuple.At(1).Type()) } func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[ssa.Instruction]bool, ok bool) { @@ -358,15 +327,21 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[ } } - for _, param := range f.Params { - if r.isCtxType(param.Type()) { - checkRefs(param, false) + if isHttpHandler { + for _, v := range r.getHttpReqCtx(f) { + checkRefs(v, false) + } + } else { + for _, param := range f.Params { + if r.isCtxType(param.Type()) { + checkRefs(param, false) + } } - } - for _, param := range f.FreeVars { - if r.isCtxType(param.Type()) { - checkRefs(param, false) + for _, param := range f.FreeVars { + if r.isCtxType(param.Type()) { + checkRefs(param, false) + } } } @@ -386,14 +361,6 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[ } } - if !isHttpHandler { - return - } - - for _, v := range r.getHttpReqCtx(f) { - checkRefs(v, false) - } - return } @@ -421,40 +388,34 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) { checkInstr = func(instr ssa.Instruction, fromAddr bool) { switch i := instr.(type) { case ssa.CallInstruction: + // r.Context() only has one recv + if len(i.Common().Args) != 1 { + break + } + // find r.Context() if r.getCallInstrCtxType(i)&CtxOut != CtxOut { break } - for _, v := range i.Common().Args { - if !r.isHttpReqType(v.Type()) { - continue - } - - f := r.getFunction(instr) - if f == nil { - continue - } - - // check is r.Context - if f.Signature.Recv() != nil && r.isHttpReqType(f.Signature.Recv().Type()) && f.Name() == ctxName { - // collect the return of r.Context - rets = append(rets, i.Value()) - } + // check is r.Context + f := r.getFunction(instr) + if f == nil || f.Name() != ctxName { + break + } + if f.Signature.Recv() != nil { + // collect the return of r.Context + rets = append(rets, i.Value()) } case *ssa.Store: if !fromAddr { checkRefs(i.Addr, true) } case *ssa.UnOp: - if r.isHttpReqType(i.Type()) { - checkRefs(i, false) - } - case *ssa.MakeClosure: + checkRefs(i, false) case *ssa.Phi: - if r.isHttpReqType(i.Type()) { - checkRefs(i, false) - } + checkRefs(i, false) + case *ssa.MakeClosure: case *ssa.Extract: // http.Request can only be input } @@ -463,20 +424,15 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) { for _, param := range f.Params { if r.isHttpReqType(param.Type()) { checkRefs(param, false) - } - } - - for _, param := range f.FreeVars { - if r.isHttpReqType(param.Type()) { - checkRefs(param, false) + break } } return } -func (r *runner) checkFuncWithCtx(f *ssa.Function, tp int) { - isHttpHandler := tp&EntryWithHttpHandler != 0 +func (r *runner) checkFuncWithCtx(f *ssa.Function, tp entryType) { + isHttpHandler := tp == EntryWithHttpHandler refMap, ok := r.collectCtxRef(f, isHttpHandler) if !ok { return @@ -496,15 +452,14 @@ func (r *runner) checkFuncWithCtx(f *ssa.Function, tp int) { if tp&CtxIn != 0 { if !refMap[instr] { - r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead") + if isHttpHandler { + r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead") + } else { + r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` instead") + } } } - // only check if the ctx used in the current function is r.Context() - if isHttpHandler { - continue - } - ff := r.getFunction(instr) if ff == nil { continue @@ -564,13 +519,13 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo continue } - if entryType := r.checkIsEntry(ff); entryType&Entry == 0 { + if entryType := r.checkIsEntry(ff); entryType == EntryNone { // cannot get info from fact, skip if ff.Blocks == nil { continue } - // handler ring call + // handler cycle call if checkingMap[key] { continue } @@ -681,23 +636,21 @@ func (r *runner) isCtxType(tp types.Type) bool { } func (r *runner) isHttpResType(tp types.Type) bool { - var ok bool for _, v := range r.httpResTyps { - if ok = types.Identical(v, v); ok { - break + if ok := types.Identical(v, v); ok { + return true } } - return ok + return false } func (r *runner) isHttpReqType(tp types.Type) bool { - var ok bool for _, v := range r.httpReqTyps { - if ok = types.Identical(tp, v); ok { - break + if ok := types.Identical(tp, v); ok { + return true } } - return ok + return false } func (r *runner) getValue(key string, f *ssa.Function) (res resInfo, ok bool) { diff --git a/testdata/src/a/a.go b/testdata/src/a/a.go index ac9e401..6c4b985 100644 --- a/testdata/src/a/a.go +++ b/testdata/src/a/a.go @@ -48,7 +48,7 @@ func f1(ctx context.Context) { f2(ctx) }(ctx) - f2(context.Background()) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead" + f2(context.Background()) // want "Non-inherited new context, use function like `context.WithXXX` instead" thunk := MyInt.F thunk(0) @@ -66,7 +66,7 @@ func f3() { func f4(ctx context.Context) { f2(ctx) ctx = context.Background() - f2(ctx) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead" + f2(ctx) // want "Non-inherited new context, use function like `context.WithXXX` instead" } func f5(ctx context.Context) { @@ -104,11 +104,20 @@ func f9(w http.ResponseWriter, r *http.Request) { f8(context.Background(), w, r) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead" } -func f10() { +func f10(in bool, w http.ResponseWriter, r *http.Request) { + f8(r.Context(), w, r) + f8(context.Background(), w, r) +} + +func f11() { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - f9(w, r) f8(r.Context(), w, r) f8(context.Background(), w, r) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead" + + f9(w, r) + + // f10 should be like `func f10(ctx context.Context, in bool, w http.ResponseWriter, r *http.Request)` + f10(true, w, r) // want "Function `f10` should pass the context parameter" }) } @@ -116,7 +125,7 @@ func f10() { type MySlice[T int | float32] []T -func (s MySlice[T]) f11(ctx context.Context) T { +func (s MySlice[T]) f12(ctx context.Context) T { f3() // generics, Block is nil, wont report var sum T @@ -126,7 +135,7 @@ func (s MySlice[T]) f11(ctx context.Context) T { return sum } -func f12[T int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64](ctx context.Context, a, b T) T { +func f13[T int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64](ctx context.Context, a, b T) T { f3() // generics, Block is nil, wont report if a > b {