forked from jschaf/pggen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
type_resolver.go
172 lines (153 loc) · 4.65 KB
/
type_resolver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
package golang
import (
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/leg100/pggen/internal/casing"
	"github.com/leg100/pggen/internal/codegen/golang/gotype"
	"github.com/leg100/pggen/internal/pg"
)
// TypeResolver handles the mapping between Postgres and Go types.
type TypeResolver struct {
	// caser converts Postgres names into Go identifiers.
	caser casing.Caser
	// overrides maps a Postgres type name (and every known alias of it, as
	// expanded by NewTypeResolver) to the Go type name the user wants.
	overrides map[string]string
}

// NewTypeResolver returns a TypeResolver that uses c to build Go identifiers
// and that honors the user-supplied overrides, keyed by Postgres type name.
//
// Each override key is expanded to all of its known aliases, so an override
// for "integer" also applies to "int" and "int4". The expansion is
// deterministic: keys are processed in sorted order, and a key that the user
// wrote explicitly always beats an alias expanded from a different key. The
// overrides map passed in is not modified or retained.
func NewTypeResolver(c casing.Caser, overrides map[string]string) TypeResolver {
	// Sort the keys so that, when two overrides expand to the same alias, the
	// outcome doesn't depend on Go's randomized map iteration order.
	keys := make([]string, 0, len(overrides))
	for k := range overrides {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	overs := make(map[string]string, len(overrides))
	for _, k := range keys {
		for _, alias := range listAliases(k) {
			overs[alias] = overrides[k]
		}
	}
	// Explicitly named keys take precedence over aliases of other keys, e.g.
	// {"int4": A, "integer": B} keeps "int4" -> A even though "integer" also
	// expands to "int4".
	for k, v := range overrides {
		overs[k] = v
	}
	return TypeResolver{caser: c, overrides: overs}
}
// Resolve maps a Postgres type to a Go type.
//
// Resolution happens in this order:
//  1. A user override whose key (including aliases expanded by
//     NewTypeResolver) equals pgt.String() yields an opaque Go type.
//  2. A Go type that gotype already knows for pgt's OID is returned with pgt
//     attached. The nullable flag selects between the nullable and
//     non-nullable known-type lookups.
//  3. Otherwise, array, enum, and composite Postgres types become new Go
//     types that pggen defines under pkgPath. Array element types are
//     resolved recursively with the same nullable flag.
//
// Any other Postgres type yields an error. An error, not a panic, is also
// returned when a known Go type's kind does not match pgt's concrete type.
func (tr TypeResolver) Resolve(pgt pg.Type, nullable bool, pkgPath string) (gotype.Type, error) {
	// Custom user override.
	if goType, ok := tr.overrides[pgt.String()]; ok {
		opaque := gotype.NewOpaqueType(goType)
		opaque.PgTyp = pgt
		return opaque, nil
	}

	// Known type.
	var typ gotype.Type
	var isKnownType bool
	if nullable {
		typ, isKnownType = gotype.FindKnownTypeNullable(pgt.OID())
	} else {
		typ, isKnownType = gotype.FindKnownTypeNonNullable(pgt.OID())
	}
	if isKnownType {
		switch typ := typ.(type) {
		case gotype.ArrayType:
			pgArray, ok := pgt.(pg.ArrayType)
			if !ok {
				return nil, fmt.Errorf("resolve known array type for postgres type %s oid=%d: got %T, want pg.ArrayType", pgt.String(), pgt.OID(), pgt)
			}
			typ.PgArray = pgArray
			return typ, nil
		case gotype.CompositeType:
			pgComposite, ok := pgt.(pg.CompositeType)
			if !ok {
				return nil, fmt.Errorf("resolve known composite type for postgres type %s oid=%d: got %T, want pg.CompositeType", pgt.String(), pgt.OID(), pgt)
			}
			typ.PgComposite = pgComposite
			return typ, nil
		case gotype.EnumType:
			pgEnum, ok := pgt.(pg.EnumType)
			if !ok {
				return nil, fmt.Errorf("resolve known enum type for postgres type %s oid=%d: got %T, want pg.EnumType", pgt.String(), pgt.OID(), pgt)
			}
			typ.PgEnum = pgEnum
			return typ, nil
		case gotype.OpaqueType:
			typ.PgTyp = pgt
			return typ, nil
		case gotype.VoidType:
			return gotype.VoidType{}, nil
		default:
			// Report the Go type that lacks a case above; the Postgres type
			// alone doesn't say which case is missing.
			return nil, fmt.Errorf("resolve unhandled known go type %T for postgres type %s oid=%d", typ, pgt.String(), pgt.OID())
		}
	}

	// New type that pggen will define in generated source code.
	switch pgt := pgt.(type) {
	case pg.ArrayType:
		elemType, err := tr.Resolve(pgt.ElemType, nullable, pkgPath)
		if err != nil {
			return nil, fmt.Errorf("resolve array elem type for array type %q: %w", pgt.Name, err)
		}
		return gotype.NewArrayType(pkgPath, pgt, tr.caser, elemType), nil
	case pg.EnumType:
		enum := gotype.NewEnumType(pkgPath, pgt, tr.caser)
		return enum, nil
	case pg.CompositeType:
		comp, err := CreateCompositeType(pkgPath, pgt, tr, tr.caser)
		if err != nil {
			return nil, fmt.Errorf("create composite type: %w", err)
		}
		return comp, nil
	}
	return nil, fmt.Errorf("no go type found for Postgres type %s oid=%d", pgt.String(), pgt.OID())
}
// CreateCompositeType creates a struct to represent a Postgres composite type.
// The type is rooted under pkgPath.
//
// The struct name and field names come from caser.ToUpperGoIdent. When that
// yields an empty identifier, gotype.ChooseFallbackName is used instead, with
// the hint "UnnamedStruct" or "UnnamedField<column index>". Every column type
// is resolved through resolver with nullable set to true.
//
// It returns an error if pgt does not have the same number of column names
// and column types, or if any column type cannot be resolved.
func CreateCompositeType(
	pkgPath string,
	pgt pg.CompositeType,
	resolver TypeResolver,
	caser casing.Caser,
) (gotype.CompositeType, error) {
	// ColumnNames and ColumnTypes are parallel slices. Reject a mismatch up
	// front instead of panicking on an out-of-range index (too few types) or
	// producing nil field types (too many types) below.
	if len(pgt.ColumnNames) != len(pgt.ColumnTypes) {
		return gotype.CompositeType{}, fmt.Errorf(
			"create composite type %s: %d column names but %d column types",
			pgt.Name, len(pgt.ColumnNames), len(pgt.ColumnTypes))
	}

	name := caser.ToUpperGoIdent(pgt.Name)
	if name == "" {
		name = gotype.ChooseFallbackName(pgt.Name, "UnnamedStruct")
	}

	fieldNames := make([]string, len(pgt.ColumnNames))
	fieldTypes := make([]gotype.Type, len(pgt.ColumnTypes))
	for i, colName := range pgt.ColumnNames {
		ident := caser.ToUpperGoIdent(colName)
		if ident == "" {
			ident = gotype.ChooseFallbackName(colName, "UnnamedField"+strconv.Itoa(i))
		}
		fieldNames[i] = ident

		// NOTE(review): columns are always resolved as nullable, presumably
		// because a composite value's fields may be NULL regardless of column
		// constraints — confirm.
		fieldType, err := resolver.Resolve(pgt.ColumnTypes[i], true /* nullable */, pkgPath)
		if err != nil {
			return gotype.CompositeType{}, fmt.Errorf("resolve composite column type %s.%s: %w", pgt.Name, colName, err)
		}
		fieldTypes[i] = fieldType
	}

	ct := gotype.CompositeType{
		PgComposite: pgt,
		PkgPath:     pkgPath,
		Pkg:         gotype.ExtractShortPackage([]byte(pkgPath)),
		Name:        name,
		FieldNames:  fieldNames,
		FieldTypes:  fieldTypes,
	}
	return ct, nil
}
// listAliases lists all known aliases for a Postgres type name, including the
// name itself. An array type name written with a leading underscore, such as
// "_int4", expands to the aliases of its element type, each prefixed with "_".
func listAliases(name string) []string {
	if strings.HasPrefix(name, "_") {
		// listElemAliases returns a fresh slice on every call, so rewriting
		// its entries in place can't corrupt shared state.
		aliases := listElemAliases(name[1:])
		for i, alias := range aliases {
			aliases[i] = "_" + alias
		}
		return aliases
	}
	return listElemAliases(name)
}

// listElemAliases lists all known type aliases for a type name. The requested
// type name is included in the list. Names with no known alias are returned
// alone.
// https://www.postgresql.org/docs/13/datatype.html#DATATYPE-TABLE
func listElemAliases(name string) []string {
	switch name {
	case "bigint", "int8":
		return []string{"bigint", "int8"}
	case "bigserial", "serial8":
		return []string{"bigserial", "serial8"}
	case "bit varying", "varbit":
		return []string{"bit varying", "varbit"}
	case "bool", "boolean":
		return []string{"bool", "boolean"}
	case "character varying", "varchar":
		return []string{"character varying", "varchar"}
	case "float8", "double precision":
		return []string{"float8", "double precision"}
	case "int", "integer", "int4":
		return []string{"int", "integer", "int4"}
	case "numeric", "decimal":
		return []string{"numeric", "decimal"}
	case "real", "float4":
		return []string{"real", "float4"}
	case "smallint", "int2":
		return []string{"smallint", "int2"}
	case "smallserial", "serial2":
		return []string{"smallserial", "serial2"}
	case "serial", "serial4":
		return []string{"serial", "serial4"}
	case "time", "time without time zone":
		return []string{"time", "time without time zone"}
	case "time with time zone", "timetz":
		return []string{"time with time zone", "timetz"}
	case "timestamp", "timestamp without time zone":
		return []string{"timestamp", "timestamp without time zone"}
	case "timestamp with time zone", "timestamptz":
		return []string{"timestamp with time zone", "timestamptz"}
	default:
		// character/char is intentionally not aliased: the bare name "char"
		// also names Postgres's internal single-byte "char" type, so mapping
		// it to "character" could apply an override to the wrong type.
		return []string{name}
	}
}