This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
/
pflag_provider.go
90 lines (74 loc) · 2.25 KB
/
pflag_provider.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package api
import (
"bytes"
"fmt"
"go/types"
"io/ioutil"
"os"
"time"
"github.com/ernesto-jimenez/gogen/imports"
goimports "golang.org/x/tools/imports"
)
// PFlagProvider bundles a type name, the package it belongs to and its fields, and renders them through the
// code/test generators (see WriteCodeFile and WriteTestFile) into files on disk.
type PFlagProvider struct {
// typeName is passed to the generators as TypeInfo.Name.
typeName string
// pkg supplies the package name used both for TypeInfo.Package and as the base package when collecting imports.
pkg *types.Package
// fields are passed to the generators as TypeInfo.Fields; each field's Typ is also inspected for required imports.
fields []FieldInfo
}
// Imports gathers the imports needed by the types of the provider's fields, relative to the provider's own
// package (types declared in that package are presumably skipped by the collector — confirm against the
// gogen imports package). The result is whatever the collector's Imports() reports.
func (p PFlagProvider) Imports() map[string]string {
	collector := imports.New(p.pkg.Name())

	// Feed every field type to the collector so it can record the packages those types come from.
	for i := range p.fields {
		collector.AddImportsFrom(p.fields[i].Typ)
	}

	return collector.Imports()
}
// WriteCodeFile evaluates the main code file generator (GenerateCodeFile) and writes the output to
// outputFilePath.
//
// If generation fails nothing is written; the returned error wraps the underlying error (so callers can use
// errors.Is / errors.As) and its message includes whatever source had been produced in the buffer at that point.
// Otherwise the result of writing the file is returned.
func (p PFlagProvider) WriteCodeFile(outputFilePath string) error {
	var buf bytes.Buffer
	if err := p.generate(GenerateCodeFile, &buf, outputFilePath); err != nil {
		// %w renders exactly like %v here but keeps the cause reachable through the error chain.
		return fmt.Errorf("error generating code, Error: %w. Source: %v", err, buf.String())
	}

	return p.writeToFile(&buf, outputFilePath)
}
// WriteTestFile evaluates the test code file generator (GenerateTestFile) and writes the output to
// outputFilePath.
//
// If generation fails nothing is written; the returned error wraps the underlying error (so callers can use
// errors.Is / errors.As) and its message includes whatever source had been produced in the buffer at that point.
// Otherwise the result of writing the file is returned.
func (p PFlagProvider) WriteTestFile(outputFilePath string) error {
	var buf bytes.Buffer
	if err := p.generate(GenerateTestFile, &buf, outputFilePath); err != nil {
		// %w renders exactly like %v here but keeps the cause reachable through the error chain.
		return fmt.Errorf("error generating code, Error: %w. Source: %v", err, buf.String())
	}

	return p.writeToFile(&buf, outputFilePath)
}
// writeToFile writes the full contents of buffer to fileName and returns any error from the write.
//
// The permission argument only applies when the file is created. It is 0644 rather than os.ModePerm (0777)
// because generated source files have no reason to be created executable or world-writable.
func (p PFlagProvider) writeToFile(buffer *bytes.Buffer, fileName string) error {
	const generatedFilePerm = os.FileMode(0644)

	return ioutil.WriteFile(fileName, buffer.Bytes(), generatedFilePerm)
}
// generate runs generator against a TypeInfo built from the provider and leaves the final source in buffer.
//
// The generator writes its raw output into buffer; that output is then passed through goimports.Process and
// buffer is replaced with the processed bytes. targetFileName is only handed to goimports.Process, where it
// influences how imports are generated/optimized. If the generator or the import processing fails, the error
// is returned as-is and buffer keeps whatever the generator wrote.
func (p PFlagProvider) generate(generator func(buffer *bytes.Buffer, info TypeInfo) error, buffer *bytes.Buffer, targetFileName string) error {
	var info TypeInfo
	info.Name = p.typeName
	info.Fields = p.fields
	info.Package = p.pkg.Name()
	info.Timestamp = time.Now()
	info.Imports = p.Imports()

	err := generator(buffer, info)
	if err != nil {
		return err
	}

	// Update imports
	processed, err := goimports.Process(targetFileName, buffer.Bytes(), nil)
	if err != nil {
		return err
	}

	// Swap the raw generator output for the processed source.
	buffer.Reset()
	_, err = buffer.Write(processed)

	return err
}
// newPflagProvider builds a PFlagProvider from the given package, type name and field list. The arguments are
// stored as-is (the fields slice is not copied).
func newPflagProvider(pkg *types.Package, typeName string, fields []FieldInfo) PFlagProvider {
	var provider PFlagProvider
	provider.typeName = typeName
	provider.pkg = pkg
	provider.fields = fields

	return provider
}