diff --git a/process.go b/process.go index c4a7724..6e61ee5 100644 --- a/process.go +++ b/process.go @@ -40,20 +40,7 @@ func Process(filename string, src []byte) ([]byte, error) { return false } if fd.Body != nil { - var prefix string - if fd.Recv != nil && len(fd.Recv.List) > 0 { - if rn, ok := fd.Recv.List[0].Type.(*ast.StarExpr); ok { - if idt, ok := rn.X.(*ast.Ident); ok { - prefix = genSegName(idt.Name) - } - } else if idt, ok := fd.Recv.List[0].Type.(*ast.Ident); ok { - prefix = genSegName(idt.Name) - } - } - sn := genSegName(fd.Name.Name) - if len(prefix) != 0 { - sn = prefix + "_" + sn - } + sn := getSegName(fd) vn, t := parseParams(fd.Type) var ds ast.Stmt switch t { @@ -141,37 +128,25 @@ func existFromContext(pn string, s ast.Stmt) bool { // ex: // defer newrelic.FromContext(ctx).StartSegment("slow").End() func buildDeferStmt(pos token.Pos, pkgName, ctxName, segName string) *ast.DeferStmt { - return &ast.DeferStmt{ - Call: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: &ast.Ident{NamePos: pos, Name: pkgName}, - Sel: &ast.Ident{NamePos: pos, Name: "FromContext"}, - }, - Args: []ast.Expr{&ast.Ident{NamePos: pos, Name: ctxName}}, - }, - Sel: &ast.Ident{NamePos: pos, Name: "StartSegment"}, - }, - Args: []ast.Expr{&ast.BasicLit{ - ValuePos: pos, - Kind: token.STRING, - Value: strconv.Quote(segName), - }}, - }, - Sel: &ast.Ident{NamePos: pos, Name: "End"}, - }, - Rparen: pos, - }, - } + arg := &ast.Ident{NamePos: pos, Name: ctxName} + return skeletonDeferStmt(pos, arg, pkgName, segName) } // buildDeferStmt builds the defer statement with *http.Request. // ex: // defer newrelic.FromContext(req.Context()).StartSegment("slow").End() func buildDeferStmtWithHttpRequest(pos token.Pos, pkgName, reqName, segName string) *ast.DeferStmt { + arg := &ast.CallExpr{ + Fun: &ast.SelectorExpr{ + X: &ast.Ident{NamePos: pos, Name: reqName}, + Sel: &ast.Ident{NamePos: pos, Name: "Context"}, + }, + Rparen: pos, + } + return skeletonDeferStmt(pos, arg, pkgName, segName) +} + +func skeletonDeferStmt(pos token.Pos, fcArg ast.Expr, pkgName, segName string) *ast.DeferStmt { return &ast.DeferStmt{ Call: &ast.CallExpr{ Fun: &ast.SelectorExpr{ @@ -183,15 +158,7 @@ func buildDeferStmtWithHttpRequest(pos token.Pos, pkgName, reqName, segName stri Sel: &ast.Ident{NamePos: pos, Name: "FromContext"}, }, Lparen: pos, - Args: []ast.Expr{ - &ast.CallExpr{ - Fun: &ast.SelectorExpr{ - X: &ast.Ident{NamePos: pos, Name: reqName}, - Sel: &ast.Ident{NamePos: pos, Name: "Context"}, - }, - Rparen: pos, - }, - }, + Args: []ast.Expr{fcArg}, Rparen: pos, }, Sel: &ast.Ident{NamePos: pos, Name: "StartSegment"}, @@ -209,11 +176,29 @@ func buildDeferStmtWithHttpRequest(pos token.Pos, pkgName, reqName, segName stri } } +func getSegName(fd *ast.FuncDecl) string { + var prefix string + if fd.Recv != nil && len(fd.Recv.List) > 0 { + if rn, ok := fd.Recv.List[0].Type.(*ast.StarExpr); ok { + if idt, ok := rn.X.(*ast.Ident); ok { + prefix = toSnake(idt.Name) + } + } else if idt, ok := fd.Recv.List[0].Type.(*ast.Ident); ok { + prefix = toSnake(idt.Name) + } + } + sn := toSnake(fd.Name.Name) + if len(prefix) != 0 { + sn = prefix + "_" + sn + } + return sn +} + // https://www.golangprograms.com/golang-convert-string-into-snake-case.html var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") -func genSegName(n string) string { +func toSnake(n string) string { snake := matchFirstCap.ReplaceAllString(n, "${1}_${2}") snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") return strings.ToLower(snake) diff --git a/process_test.go b/process_test.go index 1dfd3ea..3097d54 100644 --- a/process_test.go +++ b/process_test.go @@ -306,8 +306,8 @@ func Test_genSegName(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - if got := genSegName(tt.n); got != tt.want { - t.Errorf("genSegName() = %q, want %q", got, tt.want) + if got := toSnake(tt.n); got != tt.want { + t.Errorf("toSnake() = %q, want %q", got, tt.want) } }) }