From 7643f65052f792afbe6fcdd344afebe112a395da Mon Sep 17 00:00:00 2001 From: Dominik Honnef Date: Fri, 1 Jan 2021 10:20:11 +0100 Subject: [PATCH] staticcheck: improve extractConsts helper In commit ae6041020cd7fb7ad9a5d6c251130b93487d0a73 we changed extractConsts to not collect all distinct constants from phi nodes. We didn't notice that the function no longer returned more than one constant and didn't need to return a slice anymore. In this change, we change the signature and name, and replace all looping with a single if != nil check. Additionally, we introduce the extractConstExpectKind helper, which takes care of checking that the constant is of the right kind, and not untyped nil. This helper removes a lot of boilerplate. (cherry picked from commit 173553237fc1eb9293e0810c58bba23c08407546) --- staticcheck/lint.go | 2 +- staticcheck/rules.go | 64 +++++++++++++------------------------------- 2 files changed, 19 insertions(+), 47 deletions(-) diff --git a/staticcheck/lint.go b/staticcheck/lint.go index b3274c3e3..0dec8f26c 100644 --- a/staticcheck/lint.go +++ b/staticcheck/lint.go @@ -3154,7 +3154,7 @@ func CheckWriterBufferModified(pass *analysis.Pass) (interface{}, error) { func loopedRegexp(name string) CallCheck { return func(call *Call) { - if len(extractConsts(call.Args[0].Value.Value)) == 0 { + if extractConst(call.Args[0].Value.Value) == nil { return } if !isInLoop(call.Instr.Block()) { diff --git a/staticcheck/rules.go b/staticcheck/rules.go index 4bc6f7a5e..d4e4c1968 100644 --- a/staticcheck/rules.go +++ b/staticcheck/rules.go @@ -56,26 +56,28 @@ func (arg *Argument) Invalid(msg string) { type CallCheck func(call *Call) -func extractConsts(v ir.Value) []*ir.Const { +func extractConstExpectKind(v ir.Value, kind constant.Kind) *ir.Const { + k := extractConst(v) + if k == nil || k.Value == nil || k.Value.Kind() != kind { + return nil + } + return k +} + +func extractConst(v ir.Value) *ir.Const { v = irutil.Flatten(v) switch v := v.(type) { case *ir.Const: - return []*ir.Const{v} + return v case *ir.MakeInterface: - return extractConsts(v.X) + return extractConst(v.X) default: return nil } } func ValidateRegexp(v Value) error { - for _, c := range extractConsts(v.Value) { - if c.Value == nil { - continue - } - if c.Value.Kind() != constant.String { - continue - } + if c := extractConstExpectKind(v.Value, constant.String); c != nil { s := constant.StringVal(c.Value) if _, err := regexp.Compile(s); err != nil { return err @@ -85,13 +87,7 @@ func ValidateRegexp(v Value) error { } func ValidateTimeLayout(v Value) error { - for _, c := range extractConsts(v.Value) { - if c.Value == nil { - continue - } - if c.Value.Kind() != constant.String { - continue - } + if c := extractConstExpectKind(v.Value, constant.String); c != nil { s := constant.StringVal(c.Value) s = strings.Replace(s, "_", " ", -1) s = strings.Replace(s, "Z", "-", -1) @@ -104,13 +100,7 @@ func ValidateTimeLayout(v Value) error { } func ValidateURL(v Value) error { - for _, c := range extractConsts(v.Value) { - if c.Value == nil { - continue - } - if c.Value.Kind() != constant.String { - continue - } + if c := extractConstExpectKind(v.Value, constant.String); c != nil { s := constant.StringVal(c.Value) _, err := url.Parse(s) if err != nil { @@ -121,13 +111,7 @@ func ValidateURL(v Value) error { } func InvalidUTF8(v Value) bool { - for _, c := range extractConsts(v.Value) { - if c.Value == nil { - continue - } - if c.Value.Kind() != constant.String { - continue - } + if c := extractConstExpectKind(v.Value, constant.String); c != nil { s := constant.StringVal(c.Value) if !utf8.ValidString(s) { return true @@ -269,13 +253,7 @@ func validatePort(s string) bool { } func ValidHostPort(v Value) bool { - for _, k := range extractConsts(v.Value) { - if k.Value == nil { - continue - } - if k.Value.Kind() != constant.String { - continue - } + if k := extractConstExpectKind(v.Value, constant.String); k != nil { s := constant.StringVal(k.Value) _, port, err := net.SplitHostPort(s) if err != nil { @@ -296,17 +274,11 @@ func ConvertedFrom(v Value, typ string) bool { } func UniqueStringCutset(v Value) bool { - for _, c := range extractConsts(v.Value) { - if c.Value == nil { - continue - } - if c.Value.Kind() != constant.String { - continue - } + if c := extractConstExpectKind(v.Value, constant.String); c != nil { s := constant.StringVal(c.Value) rs := runeSlice(s) if len(rs) < 2 { - continue + return true } sort.Sort(rs) for i, r := range rs[1:] {