forked from whyrusleeping/cbor-gen
/
writefile.go
146 lines (128 loc) · 3.99 KB
/
writefile.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package typegen
import (
"bytes"
"go/format"
"os"
"sort"
"golang.org/x/xerrors"
)
// WriteTupleFileEncodersToFile generates array backed MarshalCBOR and UnmarshalCBOR implementations for the
// given types in the specified file, with the specified package name.
//
// The MarshalCBOR and UnmarshalCBOR implementations will marshal/unmarshal each type's fields as a
// fixed-length CBOR array of field values.
func WriteTupleEncodersToFile(fname, pkg string, flattenEmbeddedStruct bool,
fieldOrder []string, types ...interface{}) error {
buf := new(bytes.Buffer)
typeInfos := make([]*GenTypeInfo, len(types))
embeddedByPointerStructsInfos := make([]*[]string, len(types))
for i, t := range types {
gti, embeddedByPointerStructs, err := ParseTypeInfo(t, flattenEmbeddedStruct)
if err != nil {
return xerrors.Errorf("failed to parse type info: %w", err)
}
if fieldOrder != nil {
ordered := make([]Field, 0, len(gti.Fields))
fieldMap := map[string]*Field{}
for i, f := range gti.Fields {
fieldMap[f.Name] = >i.Fields[i]
}
// First the fields specified in `fieldOrder`
for _, name := range fieldOrder {
if f, ok := fieldMap[name]; ok {
// Mark as picked
delete(fieldMap, name)
ordered = append(ordered, *f)
}
}
// The remaining fields
for _, f := range gti.Fields {
if _, ok := fieldMap[f.Name]; ok {
ordered = append(ordered, f)
}
}
// Assert that len(ordered) matches the field count, should never panic
if len(ordered) != len(gti.Fields) {
panic("Bug: len(ordered) != len(gti.Fields)")
}
// Replace gti.Fields with ordered fields
gti.Fields = ordered
}
typeInfos[i] = gti
if flattenEmbeddedStruct {
embeddedByPointerStructsInfos[i] = embeddedByPointerStructs
}
}
if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil {
return xerrors.Errorf("failed to write header: %w", err)
}
for i, t := range typeInfos {
if err := GenTupleEncodersForType(t, flattenEmbeddedStruct,
embeddedByPointerStructsInfos[i], buf); err != nil {
return xerrors.Errorf("failed to generate encoders: %w", err)
}
}
data, err := format.Source(buf.Bytes())
if err != nil {
return err
}
fi, err := os.Create(fname)
if err != nil {
return xerrors.Errorf("failed to open file: %w", err)
}
_, err = fi.Write(data)
if err != nil {
_ = fi.Close()
return err
}
_ = fi.Close()
return nil
}
// WriteMapEncodersToFile generates map backed MarshalCBOR and UnmarshalCBOR implementations for
// the given types in the specified file, with the specified package name.
//
// The MarshalCBOR and UnmarshalCBOR implementations will marshal/unmarshal each type's fields as a
// map of field names to field values.
//
// Before generation, each type's fields are sorted in place by name using mapKeySort_RFC7049Less.
//
// The generated code is run through go/format and written to fname, which is created if missing and
// truncated otherwise. A non-nil error is returned if type parsing, code generation, formatting,
// or writing/closing the file fails.
func WriteMapEncodersToFile(fname, pkg string, flattenEmbeddedStruct bool,
	types ...interface{}) error {
	buf := new(bytes.Buffer)

	typeInfos := make([]*GenTypeInfo, len(types))
	embeddedByPointerStructsInfos := make([]*[]string, len(types))
	for i, t := range types {
		gti, embeddedByPointerStructs, err := ParseTypeInfo(t, flattenEmbeddedStruct)
		if err != nil {
			return xerrors.Errorf("failed to parse type info: %w", err)
		}

		// Order the fields by name so the generated map encoding is deterministic.
		sort.Slice(gti.Fields, func(a, b int) bool {
			return mapKeySort_RFC7049Less(gti.Fields[a].Name, gti.Fields[b].Name)
		})

		typeInfos[i] = gti
		// The embedded-by-pointer info is only recorded when flattening is enabled;
		// otherwise the slot stays nil.
		if flattenEmbeddedStruct {
			embeddedByPointerStructsInfos[i] = embeddedByPointerStructs
		}
	}

	if err := PrintHeaderAndUtilityMethods(buf, pkg, typeInfos); err != nil {
		return xerrors.Errorf("failed to write header: %w", err)
	}

	for i, t := range typeInfos {
		if err := GenMapEncodersForType(t, flattenEmbeddedStruct,
			embeddedByPointerStructsInfos[i], buf); err != nil {
			return xerrors.Errorf("failed to generate encoders: %w", err)
		}
	}

	// Format before touching the filesystem so a formatting failure does not
	// clobber an existing file.
	data, err := format.Source(buf.Bytes())
	if err != nil {
		return err
	}

	fi, err := os.Create(fname)
	if err != nil {
		return xerrors.Errorf("failed to open file: %w", err)
	}

	if _, err := fi.Write(data); err != nil {
		// Best-effort close; the write error is the one worth reporting.
		_ = fi.Close()
		return err
	}

	// A failed Close can mean the data never reached disk, so it must not be ignored.
	if err := fi.Close(); err != nil {
		return xerrors.Errorf("failed to close file: %w", err)
	}

	return nil
}