Skip to content

Commit

Permalink
Merge pull request #29 from hexdigest/render_imports
Browse files Browse the repository at this point in the history
Render imports by @recht
  • Loading branch information
hexdigest committed Sep 22, 2020
2 parents e3ff801 + bb53266 commit 8f89b01
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 7 deletions.
55 changes: 50 additions & 5 deletions generator/generator.go
Expand Up @@ -3,6 +3,7 @@ package generator
import (
"bytes"
"path/filepath"
"sort"
"strings"

"go/ast"
Expand Down Expand Up @@ -35,7 +36,46 @@ type TemplateInputs struct {
// Interface information for template
Interface TemplateInputInterface
// Vars additional vars to pass to the template, see Options.Vars
Vars map[string]interface{}
Vars map[string]interface{}
Imports []string
}

// Import generates an import statement using a list of imports from the source file
// along with the ones from the template itself
func (t TemplateInputs) Import(imports ...string) string {
allImports := make(map[string]struct{}, len(imports)+len(t.Imports))

for _, i := range t.Imports {
allImports[strings.TrimSpace(i)] = struct{}{}
}

for _, i := range imports {
if len(i) == 0 {
continue
}

i = strings.TrimSpace(i)

if i[len(i)-1] != '"' {
i += `"`
}

if i[0] != '"' {
i = `"` + i
}

allImports[i] = struct{}{}
}

out := make([]string, 0, len(allImports))

for i := range allImports {
out = append(out, i)
}

sort.Strings(out)

return "import (\n" + strings.Join(out, "\n") + ")\n"
}

// TemplateInputInterface subset of interface information used for template generation
Expand Down Expand Up @@ -131,8 +171,12 @@ func NewGenerator(options Options) (*Generator, error) {
if srcPackage.PkgPath == dstPackage.PkgPath {
interfaceType = options.InterfaceName
srcPackageAST.Name = ""
} else if options.SourcePackageAlias != "" {
srcPackageAST.Name = options.SourcePackageAlias
} else {
if options.SourcePackageAlias != "" {
srcPackageAST.Name = options.SourcePackageAlias
}

options.Imports = append(options.Imports, `"`+srcPackage.PkgPath+`"`)
}

methods, imports, err := findInterface(fs, srcPackageAST, options.InterfaceName)
Expand All @@ -150,7 +194,7 @@ func NewGenerator(options Options) (*Generator, error) {
}
}

options.Imports = makeImports(imports)
options.Imports = append(options.Imports, makeImports(imports)...)

return &Generator{
Options: options,
Expand Down Expand Up @@ -219,7 +263,8 @@ func (g Generator) Generate(w io.Writer) error {
Type: g.interfaceType,
Methods: g.methods,
},
Vars: g.Options.Vars,
Imports: g.Options.Imports,
Vars: g.Options.Vars,
})
if err != nil {
return err
Expand Down
34 changes: 32 additions & 2 deletions generator/generator_test.go
Expand Up @@ -484,7 +484,7 @@ func TestGenerator_Generate(t *testing.T) {
tests := []struct {
name string
init func(t minimock.Tester) Generator
inspect func(r Generator, t *testing.T) //inspects Generator after execution of Generate
inspect func(r Generator, w io.Writer, t *testing.T) //inspects Generator after execution of Generate

args func(t minimock.Tester) args

Expand Down Expand Up @@ -551,6 +551,36 @@ func TestGenerator_Generate(t *testing.T) {
},
wantErr: false,
},
{
name: "imports can be generated",
init: func(t minimock.Tester) Generator {
return Generator{
Options: Options{
Imports: []string{`"github.com/pkg/errors"`, `"github.com/sirupsen/logrus"`},
},
headerTemplate: template.Must(template.New("header").Parse("package success\n")),
bodyTemplate: template.Must(template.New("body").Parse(`
{{.Import "github.com/sirupsen/logrus" }}
func test(l *logrus.Logger) {}
`)),
}
},
args: func(t minimock.Tester) args {
return args{
w: bytes.NewBuffer([]byte{}),
}
},
inspect: func(_ Generator, w io.Writer, t *testing.T) {
assert.Equal(t, `package success
import (
"github.com/sirupsen/logrus"
)
func test(l *logrus.Logger) {}
`, w.(*bytes.Buffer).String())
},
},
}

for _, tt := range tests {
Expand All @@ -564,7 +594,7 @@ func TestGenerator_Generate(t *testing.T) {
err := receiver.Generate(tArgs.w)

if tt.inspect != nil {
tt.inspect(receiver, t)
tt.inspect(receiver, tArgs.w, t)
}

if tt.wantErr {
Expand Down

0 comments on commit 8f89b01

Please sign in to comment.