-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.go
129 lines (112 loc) · 3.02 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package main
// https://github.com/deepmap/oapi-codegen/pull/707
import (
"embed"
"flag"
"io/fs"
"log"
"os"
"path"
"strings"
"text/template"
"github.com/deepmap/oapi-codegen/pkg/codegen"
"github.com/deepmap/oapi-codegen/pkg/util"
"github.com/getkin/kin-openapi/openapi3"
"gopkg.in/yaml.v3"
)
// configuration is the YAML configuration file accepted via --config.
// It embeds oapi-codegen's codegen.Configuration with `yaml:",inline"`, so
// that type's settings sit at the top level of the document, and adds the
// location of the generated file.
type configuration struct {
	codegen.Configuration `yaml:",inline"`
	// OutputFile is the path of the file the generated code is written to
	// (YAML key "output").
	OutputFile string `yaml:"output,omitempty"`
}
// templates holds the contents of the oapi-templates directory, embedded
// into the binary at build time. Its files are installed as oapi-codegen
// user template overrides by addTemplateOverrides.
//
//go:embed oapi-templates
var templates embed.FS
// main reads an OpenAPI 3 specification (first positional argument) and a
// YAML configuration file (--config). It renders Go code using the embedded
// template overrides and writes the result to the configuration's output
// file, or to standard output when no output file is configured.
//
// The --models-pkg flag (default "models") names the package the templates
// use to qualify model types.
//
// Every failure is reported through log.Fatal/log.Fatalf, which terminate
// the process with exit status 1.
func main() {
	log.SetFlags(0)

	var cfgPath, modelsPkg string
	flag.StringVar(&cfgPath, "config", "", "path to config file")
	flag.StringVar(&modelsPkg, "models-pkg", "models", "package containing models")
	flag.Parse()
	if cfgPath == "" {
		log.Fatal("--config is required")
	}
	if flag.NArg() < 1 {
		log.Fatal("Please specify a path to an OpenAPI 3.0 spec file")
	}

	// Load the specification.
	input := flag.Arg(0)
	spec, err := util.LoadSwagger(input)
	if err != nil {
		log.Fatalf("error loading openapi specification: %v", err)
	}
	// Validation is intentionally disabled: it fails on specs split across
	// separate YAML files.
	// err = spec.Validate(context.Background())
	// if err != nil {
	// 	log.Fatalf("error validating openapi specification: %v", err)
	// }

	// Load the generator configuration.
	cfgdata, err := os.ReadFile(cfgPath)
	if err != nil {
		log.Fatalf("error reading config file: %v", err)
	}
	var cfg configuration
	if err := yaml.Unmarshal(cfgdata, &cfg); err != nil {
		log.Fatalf("error unmarshaling config: %v", err)
	}

	// Generate the code.
	output, err := generate(spec, cfg.Configuration, templates, modelsPkg)
	if err != nil {
		log.Fatalf("error generating code: %v", err)
	}

	// With no "output" key, os.Create("") would fail with an unhelpful
	// "open : no such file or directory"; print the code instead.
	if cfg.OutputFile == "" {
		if _, err := os.Stdout.WriteString(output); err != nil {
			log.Fatalf("error writing output: %v", err)
		}
		return
	}
	// os.WriteFile also reports the error from closing the file, which is
	// where some write failures only become visible.
	if err := os.WriteFile(cfg.OutputFile, []byte(output), 0o666); err != nil {
		log.Fatalf("error writing output file: %v", err)
	}
}
// generate renders Go code for spec with oapi-codegen. It uses config
// extended with the template overrides embedded in templates (see
// addTemplateOverrides).
//
// It registers a "modelsPkg" template function that returns the qualifier
// to prepend to model type names. The qualifier is modelsPkg followed by a
// dot, or the empty string when modelsPkg is empty, so an empty package name
// yields unqualified names instead of a stray ".".
//
// NOTE: template functions are added to the package-level
// codegen.TemplateFunctions map, so they remain registered for any later
// codegen.Generate call in this process.
func generate(spec *openapi3.T, config codegen.Configuration, templates embed.FS, modelsPkg string) (string, error) {
	config, err := addTemplateOverrides(config, templates)
	if err != nil {
		return "", err
	}

	qualifier := ""
	if modelsPkg != "" {
		qualifier = modelsPkg + "."
	}

	// Extra template functions made available to the templates.
	templateFunctions := template.FuncMap{
		"modelsPkg": func() string {
			return qualifier
		},
	}
	for name, fn := range templateFunctions {
		codegen.TemplateFunctions[name] = fn
	}

	return codegen.Generate(spec, config)
}
// addTemplateOverrides returns a copy of config whose
// OutputOptions.UserTemplates also contains every file in templates, keyed
// by the oapi-codegen template name it overrides (see overrideTemplateName).
// Embedded templates take precedence over entries of the same name already
// present in config.OutputOptions.UserTemplates.
//
// The caller's UserTemplates map is never modified. If walking or reading
// templates fails, config is returned unchanged together with the error.
func addTemplateOverrides(config codegen.Configuration, templates embed.FS) (codegen.Configuration, error) {
	// Copy instead of writing into the caller's map: maps are shared even
	// though config itself is passed by value.
	overrides := make(map[string]string, len(config.OutputOptions.UserTemplates))
	for name, body := range config.OutputOptions.UserTemplates {
		overrides[name] = body
	}

	err := fs.WalkDir(templates, ".", func(p string, d fs.DirEntry, walkErr error) error {
		// Check walkErr before using d: WalkDir passes a nil d when the root
		// cannot be stat'ed. It reports a failed directory read with
		// walkErr set on an entry whose IsDir() is true.
		if walkErr != nil {
			return walkErr
		}
		if d.IsDir() {
			return nil
		}
		body, err := templates.ReadFile(p)
		if err != nil {
			return err
		}
		overrides[overrideTemplateName(p)] = string(body)
		return nil
	})
	if err != nil {
		return config, err
	}

	config.OutputOptions.UserTemplates = overrides
	return config, nil
}

// overrideTemplateName maps a slash-separated embedded file path to the
// template name it overrides. It replaces the file extension with ".tmpl"
// and drops the first path element (the embedded root directory). For
// example, "dir/sub/file.txt" becomes "sub/file.tmpl". A path with no
// directory component maps to "".
func overrideTemplateName(p string) string {
	name := strings.TrimSuffix(p, path.Ext(p)) + ".tmpl"
	if i := strings.IndexByte(name, '/'); i >= 0 {
		return name[i+1:]
	}
	return ""
}