// jen.go (forked from dave/jennifer).

// Package jen is a code generator for Go.
package jen
import (
"bytes"
"fmt"
"go/format"
"io"
"io/ioutil"
"sort"
"strconv"
)
// Code is the interface satisfied by every item that can be rendered as part
// of a generated file. Both methods are unexported, so only types declared in
// this package (or types embedding one of them) can satisfy it.
type Code interface {
	// isNull reports whether the item is null in the context of file f.
	// NOTE(review): exact semantics inferred from the name only — confirm
	// against the implementations.
	isNull(f *File) bool
	// render writes the item's source text to w. f is the file being
	// generated; s is presumably the enclosing statement (File.Render passes
	// nil at the top level).
	render(f *File, w io.Writer, s *Statement) error
}
// Save renders the file and writes the result to filename with permission
// bits 0644, creating or truncating the file as ioutil.WriteFile does.
// If rendering fails, the error is returned and the file is not touched.
func (f *File) Save(filename string) error {
	var rendered bytes.Buffer
	if err := f.Render(&rendered); err != nil {
		return err
	}
	return ioutil.WriteFile(filename, rendered.Bytes(), 0644)
}
// Render renders the file to the provided writer.
//
// The generated source is assembled in memory in this order:
//   - any file-level comments
//   - the package clause
//   - the import declaration
//   - the rendered body
//
// It is then passed through go/format before anything is written to w, so
// nothing reaches w unless formatting succeeds. If formatting fails, the
// returned error wraps the go/format error, so callers can inspect it with
// errors.Is/errors.As. The error message also includes the unformatted source
// to help locate the problem.
func (f *File) Render(w io.Writer) error {
	// The body is rendered into its own buffer before the header is written.
	// NOTE(review): presumably this ordering matters because rendering the
	// body records the imports that renderImports emits — confirm.
	body := &bytes.Buffer{}
	if err := f.render(f, body, nil); err != nil {
		return err
	}

	source := &bytes.Buffer{}
	if f.comments != nil {
		for _, c := range f.comments {
			if err := Comment(c).render(f, source, nil); err != nil {
				return err
			}
			if _, err := fmt.Fprint(source, "\n"); err != nil {
				return err
			}
		}
	}
	if _, err := fmt.Fprintf(source, "package %s\n\n", f.name); err != nil {
		return err
	}
	if err := f.renderImports(source); err != nil {
		return err
	}
	if _, err := source.Write(body.Bytes()); err != nil {
		return err
	}

	formatted, err := format.Source(source.Bytes())
	if err != nil {
		// %w prints the same text as %s but keeps the formatting error in the
		// chain, so callers can unwrap it instead of parsing the message.
		return fmt.Errorf("Error %w while formatting source:\n%s", err, source.String())
	}
	_, err = w.Write(formatted)
	return err
}
// renderImports writes the import declaration for every package path recorded
// in f.imports to source:
//   - With no imports, nothing is written.
//   - With exactly one import, the single-line form `import alias "path"` is
//     used.
//   - With several imports, a parenthesized block is written, sorted by path.
//
// Each path is quoted with strconv.Quote and followed by a blank line.
func (f *File) renderImports(source io.Writer) error {
	count := len(f.imports)
	if count == 0 {
		return nil
	}

	if count == 1 {
		for importPath, alias := range f.imports {
			if _, err := fmt.Fprintf(source, "import %s %s\n\n", alias, strconv.Quote(importPath)); err != nil {
				return err
			}
		}
		return nil
	}

	if _, err := io.WriteString(source, "import (\n"); err != nil {
		return err
	}
	// Map iteration order is randomized, so the paths are sorted to make the
	// generated source repeatable.
	ordered := make([]string, 0, count)
	for importPath := range f.imports {
		ordered = append(ordered, importPath)
	}
	sort.Strings(ordered)
	for _, importPath := range ordered {
		if _, err := fmt.Fprintf(source, "%s %s\n", f.imports[importPath], strconv.Quote(importPath)); err != nil {
			return err
		}
	}
	_, err := io.WriteString(source, ")\n\n")
	return err
}