Skip to content

Commit

Permalink
auto mark server-side request
Browse files Browse the repository at this point in the history
  • Loading branch information
kkHAIKE committed Sep 19, 2022
1 parent bb07f32 commit 95ac589
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions contextcheck.go
Expand Up @@ -75,6 +75,8 @@ type resInfo struct {
// reuse for doc
ReqCtx bool
Skip bool

EntryType entryType
}

type ctxFact map[string]resInfo
Expand Down Expand Up @@ -234,10 +236,18 @@ func (r *runner) noImportedContextAndHttp(f *ssa.Function) (ret bool) {
return true
}

func (r *runner) checkIsEntry(f *ssa.Function) entryType {
func (r *runner) checkIsEntry(f *ssa.Function) (ret entryType) {
// if r.noImportedContextAndHttp(f) {
// return EntryNormal
// }
key := "entry:" + f.RelString(nil)
res, ok := r.getValue(key, f)
if ok {
return res.EntryType
}
defer func() {
r.currentFact[key] = resInfo{EntryType: ret}
}()

ctxIn, ctxOut := r.checkIsCtx(f)
if ctxOut {
Expand All @@ -264,21 +274,14 @@ func (r *runner) checkIsEntry(f *ssa.Function) entryType {
}

func (r *runner) docFlag(f *ssa.Function) (reqctx, skip bool) {
key := "doc:" + f.RelString(nil)
res, ok := r.getValue(key, f)
if ok {
return res.ReqCtx, res.Skip
}

for _, v := range r.getDocFromFunc(f) {
if len(nolintRe.FindString(v.Text)) > 0 && strings.Contains(v.Text, "contextcheck") {
res.Skip = true
skip = true
} else if strings.HasPrefix(v.Text, "// @contextcheck(req_has_ctx)") {
res.ReqCtx = true
reqctx = true
}
}
r.currentFact[key] = res
return res.ReqCtx, res.Skip
return
}

var nolintRe = regexp.MustCompile(`^//\s?nolint:`)
Expand Down Expand Up @@ -333,22 +336,30 @@ func (r *runner) checkIsCtx(f *ssa.Function) (in, out bool) {
}

func (r *runner) checkIsHttpHandler(f *ssa.Function, reqctx bool) bool {
if reqctx {
tuple := f.Signature.Params()
for i := 0; i < tuple.Len(); i++ {
if r.isHttpReqType(tuple.At(i).Type()) {
return true
}
var hasReq bool
tuple := f.Signature.Params()
for i := 0; i < tuple.Len(); i++ {
if r.isHttpReqType(tuple.At(i).Type()) {
hasReq = true
break
}
}

// must has no result
if f.Signature.Results().Len() > 0 {
if !hasReq {
return false
}
if reqctx {
return true
}

// check if use r.Context()
if f.Blocks != nil && len(r.getHttpReqCtx(f, true)) > 0 {
return true
}

// must be `func f(w http.ResponseWriter, r *http.Request) {}`
tuple := f.Signature.Params()
if f.Signature.Results().Len() > 0 {
return false
}
if tuple.Len() != 2 {
return false
}
Expand Down Expand Up @@ -420,7 +431,7 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[
}

if isHttpHandler {
for _, v := range r.getHttpReqCtx(f) {
for _, v := range r.getHttpReqCtx(f, false) {
checkRefs(v, false)
}
} else {
Expand Down Expand Up @@ -456,7 +467,7 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[
return
}

func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) {
func (r *runner) getHttpReqCtx(f *ssa.Function, least1 bool) (rets []ssa.Value) {
checkedRefMap := make(map[ssa.Value]bool)

var checkRefs func(val ssa.Value, fromAddr bool)
Expand Down Expand Up @@ -498,6 +509,9 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) {
if f.Signature.Recv() != nil {
// collect the return of r.Context
rets = append(rets, i.Value())
if least1 {
return
}
}
case *ssa.Store:
if !fromAddr {
Expand All @@ -516,7 +530,6 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) {
for _, param := range f.Params {
if r.isHttpReqType(param.Type()) {
checkRefs(param, false)
break
}
}

Expand Down

0 comments on commit 95ac589

Please sign in to comment.