Skip to content

Commit

Permalink
fix(*): fix issue where generated client does not import third-party …
Browse files Browse the repository at this point in the history
…packages properly (#361)
  • Loading branch information
Xinzhao Xu authored Oct 20, 2020
1 parent dcb2879 commit 5d4d1b0
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 62 deletions.
4 changes: 2 additions & 2 deletions cmd/nirvana/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ func (o *apiOptions) Run(cmd *cobra.Command, args []string) error {
}

files := map[string][]byte{}
for _, s := range swaggers {
for filename, s := range swaggers {
data, err := json.MarshalIndent(s, "", " ")
if err != nil {
return err
}

files[s.Info.Version] = data
files[filename] = data
}

if o.Output != "" {
Expand Down
71 changes: 42 additions & 29 deletions utils/generators/golang/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,27 @@ func (g *Generator) Generate() (map[string][]byte, error) {
return nil, err
}
codes := make(map[string][]byte)
versions := make([]string, 0, len(definitions))
for version, defs := range definitions {
versions = append(versions, version)
helper, err := newHelper(g.rootPkg, defs)
versions := make([]utils.Version, 0, len(definitions))
for _, d := range definitions {
versions = append(versions, d.Version)
helper, err := newHelper(g.rootPkg, d.Defs)
if err != nil {
return nil, err
}
// all lower case string
packageName := d.Version.Module + d.Version.Name
types, imports := helper.Types()
typeCodes, err := g.typeCodes(version, types, imports)
typeCodes, err := g.typeCodes(packageName, types, imports)
if err != nil {
return nil, err
}
functions, imports := helper.Functions()
functionCodes, err := g.functionCodes(version, functions, imports)
functionCodes, err := g.functionCodes(packageName, functions, imports)
if err != nil {
return nil, err
}
codes[version+"/types"] = typeCodes
codes[version+"/client"] = functionCodes
codes[packageName+"/types"] = typeCodes
codes[packageName+"/client"] = functionCodes
}
client, err := g.aggregationClientCode(versions)
if err != nil {
Expand All @@ -98,11 +100,14 @@ func (g *Generator) typeCodes(version string, types []Type, imports []string) ([
_ = err
}

writeln("import (")
for _, pkg := range imports {
writeln(pkg)
if len(imports) > 0 {
writeln("import (")
for _, pkg := range imports {
writeln(pkg)
}
writeln(")")
}
writeln(")")

for _, typ := range types {
writeln("")
writeln(string(typ.Generate()))
Expand All @@ -116,7 +121,7 @@ func (g *Generator) functionCodes(version string, functions []function, imports
package {{ .Version }}
import (
context "context"
"context"
{{- range .Imports }}
{{.}}
Expand Down Expand Up @@ -205,78 +210,86 @@ func (c *Client) {{ .Name }}(ctx context.Context{{ range .Parameters }},{{ .Prop
}

type versionedPackage struct {
Alias string
Version string
Path string
Function string
}

func (g *Generator) aggregationClientCode(versions []string) ([]byte, error) {
func (g *Generator) aggregationClientCode(versions []utils.Version) ([]byte, error) {
buf := bytes.NewBuffer(nil)
template, err := template.New("codes").Parse(`
package {{ .PackageName }}
import (
{{ range .Pakcages }}
{{ .Version }} "{{ .Path }}"
{{ end }}
"{{ .Path }}"
{{- end }}
rest "{{ .Rest }}"
)
// Interface describes a versioned client.
type Interface interface {
{{- range .Pakcages }}
// {{ .Function }} returns {{ .Version }} client.
{{ .Function }}() {{ .Version }}.Interface
// {{ .Function }} returns {{ .Alias }} client.
{{ .Function }}() {{ .Alias }}.Interface
{{- end }}
}
// Client contains versioned clients.
type Client struct {
{{ range .Pakcages }}
{{ .Version }} *{{ .Version }}.Client
{{ end }}
{{ .Version }} *{{ .Alias }}.Client
{{- end }}
}
// NewClient creates a new client.
func NewClient(cfg *rest.Config) (Interface, error) {
c := &Client{}
var err error
{{ range .Pakcages }}
c.{{ .Version }}, err = {{ .Version }}.NewClient(cfg)
c.{{ .Version }}, err = {{ .Alias }}.NewClient(cfg)
if err != nil {
return nil, err
}
{{ end }}
{{ end -}}
return c, nil
}
// MustNewClient creates a new client or panic if an error occurs.
func MustNewClient(cfg *rest.Config) Interface {
return &Client{
{{- range .Pakcages }}
{{ .Version }}: {{ .Version }}.MustNewClient(cfg),
{{ .Version }}: {{ .Alias }}.MustNewClient(cfg),
{{- end }}
}
}
{{ range .Pakcages }}
// {{ .Function }} returns a versioned client.
func (c *Client) {{ .Function }}() {{ .Version }}.Interface {
func (c *Client) {{ .Function }}() {{ .Alias }}.Interface {
return c.{{ .Version }}
}
{{ end }}
`)
if err != nil {
return nil, err
}
packages := []versionedPackage{}
packages := make([]versionedPackage, 0, len(versions))
for _, version := range versions {
alias := version.Module + version.Name
var v string
if version.Module != "" {
v = version.Module + strings.Title(version.Name)
} else {
v = version.Name
}
packages = append(packages, versionedPackage{
Version: version,
Path: path.Join(g.pkg, version),
Function: strings.Title(version),
Alias: alias,
Version: v,
Path: path.Join(g.pkg, alias),
Function: strings.Title(version.Module) + strings.Title(version.Name),
})
}
err = template.Execute(buf, map[string]interface{}{
Expand Down
65 changes: 43 additions & 22 deletions utils/generators/golang/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ package golang

import (
"fmt"
"go/token"
"path"
"reflect"
"regexp"
"sort"
"strconv"
"strings"

"github.com/caicloud/nirvana/definition"
"github.com/caicloud/nirvana/service"
"github.com/caicloud/nirvana/utils/api"
"github.com/caicloud/nirvana/utils/generators/utils"
)

// Type abstracts common ability from type declarations.
Expand Down Expand Up @@ -147,8 +148,10 @@ type parameterExtension struct {
}

type functionParameter struct {
Source string
Name string
Source string
// Name is the name in the Parameter of the Definition, used in the generated function body.
Name string
// ProposedName is the parameter name of the generated function.
ProposedName string
Typ string
Extensions []parameterExtension
Expand Down Expand Up @@ -189,8 +192,8 @@ func newHelper(rootPkg string, definitions *api.Definitions) (*helper, error) {

// Types returns types which is required to generate.
func (h *helper) Types() ([]Type, []string) {
types := []Type{}
generatedTypes := []*api.Type{}
types := make([]Type, 0, len(h.definitions.Types))
generatedTypes := make([]*api.Type, 0, len(h.definitions.Types))

for name, typ := range h.definitions.Types {
if typ.Kind == reflect.Func || typ.PkgPath == "" ||
Expand Down Expand Up @@ -270,7 +273,7 @@ func (h *helper) packages(types []*api.Type, extended bool) []string {
pkgMap[pkg] = true
}
}
results := []string{}
results := make([]string, 0, len(pkgMap))
for pkg := range pkgMap {
alias := h.namer.Alias(pkg)
results = append(results, fmt.Sprintf(`%s "%s"`, alias, pkg))
Expand All @@ -280,14 +283,18 @@ func (h *helper) packages(types []*api.Type, extended bool) []string {

// pkgs generates a list of imported packages without aliases.
func (h *helper) pkgs(typ *api.Type, extended bool) []string {
switch typ.Kind {
case reflect.Array, reflect.Slice, reflect.Ptr:
return h.pkgs(h.definitions.Types[typ.Elem], extended)
case reflect.Map:
pkgs := h.pkgs(h.definitions.Types[typ.Key], extended)
return append(pkgs, h.pkgs(h.definitions.Types[typ.Elem], extended)...)
}

// handle third-party types and cases where the referenced type is an array/map alias
// eg:
// import (
// "go.mongodb.org/mongo-driver/bson/primitive"
// )
//
// type XXX struct {
// ID primitive.ObjectID `json:"_id"`
// }
// definition of ObjectID:
// type ObjectID [12]byte
// for this type, you don't need to import the package of its child, just the package itself
if typ.PkgPath != "" {
index := strings.LastIndex(typ.PkgPath, "/vendor/")
if index >= 0 ||
Expand All @@ -298,22 +305,34 @@ func (h *helper) pkgs(typ *api.Type, extended bool) []string {
}
return []string{pkg}
}

if extended && typ.Kind == reflect.Struct {
pkgs := []string{}
pkgs := make([]string, 0, len(typ.Fields))
for _, field := range typ.Fields {
pkgs = append(pkgs, h.pkgs(h.definitions.Types[field.Type], extended)...)
}
return pkgs
}
return nil
}
return []string{}

// handle normal array/map definitions, eg:
// Metas []*v1.ObjectMeta `json:"metas"`
// Objects map[string]*v1.Object `json:"objects"`
switch typ.Kind {
case reflect.Array, reflect.Slice, reflect.Ptr:
return h.pkgs(h.definitions.Types[typ.Elem], extended)
case reflect.Map:
return append(h.pkgs(h.definitions.Types[typ.Key], extended), h.pkgs(h.definitions.Types[typ.Elem], extended)...)
}
return nil
}

// Functions returns functions which is required to generate.
func (h *helper) Functions() ([]function, []string) {
functionNames := map[string]int{}
functions := []function{}
types := []*api.Type{}
functions := make([]function, 0, len(h.definitions.Definitions))
types := make([]*api.Type, 0, len(h.definitions.Definitions))
for path, defs := range h.definitions.Definitions {
for _, def := range defs {
fn := function{
Expand All @@ -324,7 +343,7 @@ func (h *helper) Functions() ([]function, []string) {
// The priority of summary is higher than original function name.
if def.Summary != "" {
// Remove invalid chars and regard as function name.
fn.Name = nameReplacer.ReplaceAllString(def.Summary, "")
fn.Name = utils.NameReplacer.ReplaceAllString(def.Summary, "")
}

if fn.Name == "" {
Expand Down Expand Up @@ -454,19 +473,21 @@ type nameContainer struct {
namer *typeNamer
}

var nameReplacer = regexp.MustCompile(`[^a-zA-Z0-9]`)

func (n *nameContainer) proposeName(name string, typ api.TypeName) string {
if name == "" {
name = n.deconstruct(typ)
}
name = nameReplacer.ReplaceAllString(name, "")
name = utils.NameReplacer.ReplaceAllString(name, "")
if name == "" {
name = "temp"
}
if name[0] >= 'A' && name[0] <= 'Z' {
name = string(name[0]|0x20) + name[1:]
}
// name may be `type` etc.
if token.Lookup(name).IsKeyword() {
name += "_"
}
index := n.names[name]
if index > 0 {
name += strconv.Itoa(index)
Expand Down
14 changes: 10 additions & 4 deletions utils/generators/swagger/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ func NewGenerator(
}

// Generate generates swagger specifications.
func (g *Generator) Generate() ([]spec.Swagger, error) {
func (g *Generator) Generate() (map[string]spec.Swagger, error) {
g.parseSchemas()
g.parsePaths()

swaggers := make([]spec.Swagger, 0, len(g.config.Versions))
swaggers := make(map[string]spec.Swagger, len(g.config.Versions))
for _, version := range g.config.Versions {
title := fmt.Sprintln(g.config.Project, "APIs")
description := g.config.Description
Expand All @@ -120,7 +120,13 @@ func (g *Generator) Generate() ([]spec.Swagger, error) {
schemes, hosts, contacts,
version.PathRules,
)
swaggers = append(swaggers, *swagger)
var filename string
if version.Module != "" {
filename = strings.ToLower(version.Module) + "." + strings.ToLower(version.Name)
} else {
filename = strings.ToLower(version.Name)
}
swaggers[filename] = *swagger
}

if len(swaggers) <= 0 {
Expand All @@ -129,7 +135,7 @@ func (g *Generator) Generate() ([]spec.Swagger, error) {
g.config.Schemes, g.config.Hosts, g.config.Contacts,
nil,
)
swaggers = append(swaggers, *swagger)
swaggers["unknown"] = *swagger
}
return swaggers, nil
}
Expand Down
Loading

0 comments on commit 5d4d1b0

Please sign in to comment.