Skip to content

Commit

Permalink
optimize(hz): optimize the use experience for hz client (#605)
Browse files Browse the repository at this point in the history
  • Loading branch information
FGYFFFF authored and wzekin committed Feb 23, 2023
1 parent 2d8cc9b commit dae22e5
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 36 deletions.
2 changes: 2 additions & 0 deletions cmd/hz/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ func Init() *cli.App {
useFlag := cli.StringFlag{Name: "use", Usage: "Specify the model package to import for handler.", Destination: &globalArgs.Use}
baseDomainFlag := cli.StringFlag{Name: "base_domain", Usage: "Specify the request domain.", Destination: &globalArgs.BaseDomain}
clientDirFlag := cli.StringFlag{Name: "client_dir", Usage: "Specify the client path. If not specified, IDL generated path is used for 'client' command; no client code is generated for 'new' command", Destination: &globalArgs.ClientDir}
forceClientDirFlag := cli.StringFlag{Name: "force_client_dir", Usage: "Specify the client path, and won't use namespaces as subpaths", Destination: &globalArgs.ForceClientDir}

optPkgFlag := cli.StringSliceFlag{Name: "option_package", Aliases: []string{"P"}, Usage: "Specify the package path. ({include_path}={import_path})"}
includesFlag := cli.StringSliceFlag{Name: "proto_path", Aliases: []string{"I"}, Usage: "Add an IDL search path for includes. (Valid only if idl is protobuf)"}
Expand Down Expand Up @@ -294,6 +295,7 @@ func Init() *cli.App {
&modelDirFlag,
&clientDirFlag,
&useFlag,
&forceClientDirFlag,

&includesFlag,
&thriftOptionsFlag,
Expand Down
19 changes: 10 additions & 9 deletions cmd/hz/config/argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ import (

type Argument struct {
// Mode meta.Mode // operating mode(0-compiler, 1-plugin)
CmdType string // command type
Verbose bool // print verbose log
Cwd string // execution path
OutDir string // output path
HandlerDir string // handler path
ModelDir string // model path
RouterDir string // router path
ClientDir string // client path
BaseDomain string // request domain
CmdType string // command type
Verbose bool // print verbose log
Cwd string // execution path
OutDir string // output path
HandlerDir string // handler path
ModelDir string // model path
RouterDir string // router path
ClientDir string // client path
BaseDomain string // request domain
ForceClientDir string // client dir (not use namespace as a subpath)

IdlType string // idl type
IdlPaths []string // master idl path
Expand Down
5 changes: 4 additions & 1 deletion cmd/hz/generator/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type ClientFile struct {
func (pkgGen *HttpPackageGenerator) genClient(pkg *HttpPackage, clientDir string) error {
for _, s := range pkg.Services {
cliDir := util.SubDir(clientDir, util.ToSnakeCase(s.Name))
if len(pkgGen.ForceClientDir) != 0 {
cliDir = pkgGen.ForceClientDir
}
hertzClientPath := filepath.Join(cliDir, hertzClientTplName)
isExist, err := util.PathExist(hertzClientPath)
if err != nil {
Expand All @@ -57,7 +60,7 @@ func (pkgGen *HttpPackageGenerator) genClient(pkg *HttpPackage, clientDir string
}
client := ClientFile{
FilePath: filepath.Join(cliDir, util.ToSnakeCase(s.Name)+".go"),
PackageName: util.ToSnakeCase(s.Name),
PackageName: util.ToSnakeCase(filepath.Base(cliDir)),
ServiceName: util.ToCamelCase(s.Name),
ClientMethods: s.ClientMethods,
BaseDomain: baseDomain,
Expand Down
1 change: 1 addition & 0 deletions cmd/hz/generator/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type HttpPackageGenerator struct {
UseDir string
ClientDir string
IdlClientDir string
ForceClientDir string
NeedModel bool
HandlerByMethod bool
BaseDomain string
Expand Down
13 changes: 12 additions & 1 deletion cmd/hz/generator/package_tpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func parseRequestHeader(c *cli, r *request) error {
hdr[k] = append(hdr[k], r.header[k]...)
}
if len(r.formParam) != 0 && len(r.fileParam) != 0 {
if len(r.formParam) != 0 || len(r.fileParam) != 0 {
hdr.Add(hdrContentTypeKey, formContentType)
}
Expand Down Expand Up @@ -857,6 +857,7 @@ package {{.PackageName}}
import (
"context"
"fmt"
"github.com/cloudwego/hertz/pkg/common/config"
"github.com/cloudwego/hertz/pkg/protocol"
Expand All @@ -865,6 +866,11 @@ import (
{{- end}}
)
// unused protection
var (
_ = fmt.Formatter(nil)
)
type Client interface {
{{range $_, $MethodInfo := .ClientMethods}}
{{$MethodInfo.Name}}(context context.Context, req *{{$MethodInfo.RequestTypeName}}, reqOpt ...config.RequestOption) (resp *{{$MethodInfo.ReturnTypeName}}, rawResponse *protocol.Response, err error)
Expand Down Expand Up @@ -922,6 +928,11 @@ func (s *{{$.ServiceName}}Client) {{$MethodInfo.Name}}(context context.Context,
var defaultClient, _ = New{{.ServiceName}}Client("{{.BaseDomain}}")
func ConfigDefaultClient(ops ...Option) (err error) {
defaultClient, err = NewHertzClient("{{.BaseDomain}}", ops...)
return
}
{{range $_, $MethodInfo := .ClientMethods}}
func {{$MethodInfo.Name}}(context context.Context, req *{{$MethodInfo.RequestTypeName}}, reqOpt ...config.RequestOption) (resp *{{$MethodInfo.ReturnTypeName}}, rawResponse *protocol.Response, err error) {
return defaultClient.{{$MethodInfo.Name}}(context, req, reqOpt...)
Expand Down
33 changes: 30 additions & 3 deletions cmd/hz/protobuf/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/jhump/protoreflect/desc"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)

Expand Down Expand Up @@ -249,41 +250,67 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen
hasFormAnnotation bool
)
for _, f := range inputType.Fields {
hasAnnotation := false
isStringFieldType := false
if f.Desc.Kind() == protoreflect.StringKind {
isStringFieldType = true
}
if proto.HasExtension(f.Desc.Options(), api.E_Query) {
hasAnnotation = true
queryAnnos := proto.GetExtension(f.Desc.Options(), api.E_Query)
val := queryAnnos.(string)
clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
}

if proto.HasExtension(f.Desc.Options(), api.E_Path) {
hasAnnotation = true
pathAnnos := proto.GetExtension(f.Desc.Options(), api.E_Path)
val := pathAnnos.(string)
clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
if isStringFieldType {
clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
} else {
clientMethod.PathParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName)
}
}

if proto.HasExtension(f.Desc.Options(), api.E_Header) {
hasAnnotation = true
headerAnnos := proto.GetExtension(f.Desc.Options(), api.E_Header)
val := headerAnnos.(string)
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
if isStringFieldType {
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
} else {
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName)
}
}

if proto.HasExtension(f.Desc.Options(), api.E_Form) {
hasAnnotation = true
formAnnos := proto.GetExtension(f.Desc.Options(), api.E_Form)
hasFormAnnotation = true
val := formAnnos.(string)
clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
if isStringFieldType {
clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
} else {
clientMethod.FormValueCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName)
}
}

if proto.HasExtension(f.Desc.Options(), api.E_Body) {
hasAnnotation = true
hasBodyAnnotation = true
}

if proto.HasExtension(f.Desc.Options(), api.E_FileName) {
hasAnnotation = true
fileAnnos := proto.GetExtension(f.Desc.Options(), api.E_FileName)
hasFormAnnotation = true
val := fileAnnos.(string)
clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
}
if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") {
clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", f.GoName, f.GoName)
}
}
clientMethod.BodyParamsCode = meta.SetBodyParam
if hasBodyAnnotation && hasFormAnnotation {
Expand Down
1 change: 1 addition & 0 deletions cmd/hz/protobuf/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ func (plugin *Plugin) genHttpPackage(ast *descriptorpb.FileDescriptorProto, deps
HandlerByMethod: args.HandlerByMethod,
CmdType: args.CmdType,
IdlClientDir: plugin.IdlClientDir,
ForceClientDir: args.ForceClientDir,
BaseDomain: args.BaseDomain,
}

Expand Down
46 changes: 24 additions & 22 deletions cmd/hz/thrift/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,59 +181,61 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ
hasFormAnnotation bool
)
for _, field := range st.Fields() {
rwctx, err := thriftgoUtil.MkRWCtx(thriftgoUtil.RootScope(), field)
if err != nil {
fmt.Errorf("can not get field info for %s", field.Name)
hasAnnotation := false
isStringFieldType := false
if field.GetType().String() == "string" {
isStringFieldType = true
}
if anno := getAnnotation(field.Annotations, AnnotationQuery); len(anno) > 0 {
hasAnnotation = true
query := anno[0]
if rwctx.IsPointer {
clientMethod.QueryParamsCode += fmt.Sprintf("%q: *req.%v,\n", query, field.GoName().String())
} else {
clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.%v,\n", query, field.GoName().String())
}
clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", query, field.GoName().String())
}

if anno := getAnnotation(field.Annotations, AnnotationPath); len(anno) > 0 {
hasAnnotation = true
path := anno[0]
if rwctx.IsPointer {
clientMethod.PathParamsCode += fmt.Sprintf("%q: *req.%v,\n", path, field.GoName().String())
if isStringFieldType {
clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", path, field.GoName().String())
} else {
clientMethod.PathParamsCode += fmt.Sprintf("%q: req.%v,\n", path, field.GoName().String())
clientMethod.PathParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", path, field.GoName().String())
}
}

if anno := getAnnotation(field.Annotations, AnnotationHeader); len(anno) > 0 {
hasAnnotation = true
header := anno[0]
if rwctx.IsPointer {
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: *req.%v,\n", header, field.GoName().String())
if isStringFieldType {
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", header, field.GoName().String())
} else {
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.%v,\n", header, field.GoName().String())
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", header, field.GoName().String())
}
}

if anno := getAnnotation(field.Annotations, AnnotationForm); len(anno) > 0 {
hasAnnotation = true
form := anno[0]
hasFormAnnotation = true
if rwctx.IsPointer {
clientMethod.FormValueCode += fmt.Sprintf("%q: *req.%v,\n", form, field.GoName().String())
if isStringFieldType {
clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", form, field.GoName().String())
} else {
clientMethod.FormValueCode += fmt.Sprintf("%q: req.%v,\n", form, field.GoName().String())
clientMethod.FormValueCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", form, field.GoName().String())
}
}

if anno := getAnnotation(field.Annotations, AnnotationBody); len(anno) > 0 {
hasAnnotation = true
hasBodyAnnotation = true
}

if anno := getAnnotation(field.Annotations, AnnotationFileName); len(anno) > 0 {
hasAnnotation = true
fileName := anno[0]
hasFormAnnotation = true
if rwctx.IsPointer {
clientMethod.FormFileCode += fmt.Sprintf("%q: *req.%v,\n", fileName, field.GoName().String())
} else {
clientMethod.FormFileCode += fmt.Sprintf("%q: req.%v,\n", fileName, field.GoName().String())
}
clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", fileName, field.GoName().String())
}
if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") {
clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", field.GoName().String(), field.GoName().String())
}
}
clientMethod.BodyParamsCode = meta.SetBodyParam
Expand Down
1 change: 1 addition & 0 deletions cmd/hz/thrift/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ func (plugin *Plugin) Run() int {
HandlerByMethod: args.HandlerByMethod,
CmdType: args.CmdType,
IdlClientDir: util.SubDir(modelDir, pkgInfo.Package),
ForceClientDir: args.ForceClientDir,
BaseDomain: args.BaseDomain,
}
if args.ModelBackend != "" {
Expand Down

0 comments on commit dae22e5

Please sign in to comment.