/
parser.go
127 lines (100 loc) · 2.78 KB
/
parser.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package parser
import (
"fmt"
"github.com/butzopower/natsu/core"
"github.com/butzopower/natsu/util"
"go/types"
"golang.org/x/tools/go/packages"
"os"
"strings"
)
// Parse loads the package unionPackageName, looks up the type unionTypeName
// declared in it, and describes the union embedded in that interface type.
//
// For example, for
//
//	type Shape interface{ Circle | *Square }
//
// the result holds one core.TermPath per union term, in declaration order,
// with Pointer set for terms written as a pointer (*Square) and Package/Local
// naming the pointed-to type. The returned Path and Union.Package are
// unionPackageName exactly as passed in.
//
// It returns an error when the type cannot be found, when it is not an
// interface whose first embedded element is a union, or when a term is not a
// (pointer to a) package-qualified type. Approximation terms (~T) are
// rejected: they denote every type whose underlying type is T and cannot be
// named by a single package path.
func Parse(unionPackageName, unionTypeName string) (core.UnionDetails, error) {
	obj, err := findType(unionPackageName, unionTypeName)
	if err != nil {
		return core.UnionDetails{}, err
	}
	unionType, err := extractUnion(obj)
	if err != nil {
		return core.UnionDetails{}, err
	}

	terms := make([]*types.Term, 0, unionType.Len())
	for i := 0; i < unionType.Len(); i++ {
		terms = append(terms, unionType.Term(i))
	}

	termNames, err := util.MapWithErr(terms, func(term *types.Term) (core.TermPath, error) {
		// term.String() would render "~pkg.T", leaking the "~" into the
		// package path; reject approximation terms instead of emitting a
		// bogus path.
		if term.Tilde() {
			return core.TermPath{}, fmt.Errorf("union term %s: approximation (~) terms are not supported", term)
		}
		// Inspect the term's own type, not its underlying type: a named
		// pointer type (type P *T) is distinct from *T and must keep its name.
		typ := term.Type()
		pointer, isPointer := typ.(*types.Pointer)
		if isPointer {
			typ = pointer.Elem()
		}
		pkg, local, splitErr := splitSourceType(typ.String())
		if splitErr != nil {
			return core.TermPath{}, fmt.Errorf("union term %s: %w", term, splitErr)
		}
		return core.TermPath{
			Package: pkg,
			Local:   local,
			Pointer: isPointer,
		}, nil
	})
	if err != nil {
		return core.UnionDetails{}, err
	}

	return core.UnionDetails{
		Path: unionPackageName,
		Union: core.TermPath{
			Package: unionPackageName,
			Local:   unionTypeName,
		},
		Terms: termNames,
	}, nil
}
// findType loads the package sourceTypePackage and returns the object named
// sourceTypeName from that package's top-level scope.
//
// Any error from loading the package is returned unchanged; if no object of
// that name exists in the package scope, a "not found" error is returned.
func findType(sourceTypePackage, sourceTypeName string) (types.Object, error) {
	loaded, loadErr := loadPackage(sourceTypePackage)
	if loadErr != nil {
		return nil, loadErr
	}
	if found := loaded.Types.Scope().Lookup(sourceTypeName); found != nil {
		return found, nil
	}
	return nil, fmt.Errorf("%s not found in declared types of %s", sourceTypeName, loaded)
}
// extractUnion returns the union written as the first embedded element of
// obj's interface type, e.g. the A | B in `type U interface{ A | B }`.
//
// Only embedded element 0 is examined; later embedded elements and methods
// are ignored. An error is returned when obj is not a *types.TypeName, when
// its underlying type is not an interface, when the interface embeds nothing,
// or when its first embedded element is not a *types.Union.
func extractUnion(obj types.Object) (*types.Union, error) {
	if _, isTypeName := obj.(*types.TypeName); !isTypeName {
		return nil, fmt.Errorf("%v is not a named type", obj)
	}

	iface, isIface := obj.Type().Underlying().(*types.Interface)
	switch {
	case !isIface:
		return nil, fmt.Errorf("type %v is not an interface", obj)
	case iface.NumEmbeddeds() == 0:
		return nil, fmt.Errorf("type %v does not contain embedded types", obj)
	}

	if union, isUnion := iface.EmbeddedType(0).(*types.Union); isUnion {
		return union, nil
	}
	return nil, fmt.Errorf("type %v is not a union", obj)
}
// loadPackage loads the package(s) matching path with type information
// (NeedTypes | NeedImports | NeedTypesInfo) and returns the first one.
//
// If packages.Load itself fails, the error is wrapped and returned. If any
// loaded package reports errors, they are printed via packages.PrintErrors
// and the process exits with status 1. If path matches no packages, an error
// is returned.
//
// NOTE(review): exiting from library code prevents callers from handling the
// failure; returning an error instead would be preferable (and would let this
// file drop its "os" import).
func loadPackage(path string) (*packages.Package, error) {
	cfg := &packages.Config{Mode: packages.NeedTypes | packages.NeedImports | packages.NeedTypesInfo}
	pkgs, err := packages.Load(cfg, path)
	if err != nil {
		return nil, fmt.Errorf("loading packages for inspection: %w", err)
	}
	if packages.PrintErrors(pkgs) > 0 {
		os.Exit(1)
	}
	// packages.Load may return an empty list without an error (e.g. a
	// wildcard pattern that matches nothing); indexing pkgs[0] would panic.
	if len(pkgs) == 0 {
		return nil, fmt.Errorf("no packages found for %q", path)
	}
	// A pattern may match several packages; only the first is inspected.
	return pkgs[0], nil
}
// splitSourceType splits a package-qualified type string such as
// "github.com/org/repo/pkg.MyType" into its package path
// ("github.com/org/repo/pkg") and local type name ("MyType").
//
// The split is made at the last '.' before any type-argument list. An
// instantiated generic type like "example.com/a.Box[example.com/b.T]"
// therefore yields ("example.com/a", "Box[example.com/b.T]") instead of being
// cut inside its type arguments. Go import paths may not contain '[' (see the
// ImportPath restrictions in the Go spec), so the first '[' starts the
// type-argument list.
//
// It returns an error when there is no package qualifier or no type name,
// e.g. "int", ".T", "pkg.", or composite types such as "[]pkg.T".
func splitSourceType(sourceType string) (string, string, error) {
	qualified := sourceType
	if open := strings.IndexByte(sourceType, '['); open >= 0 {
		qualified = sourceType[:open]
	}
	idx := strings.LastIndexByte(qualified, '.')
	// idx == 0 means an empty package path; a trailing '.' means an empty
	// type name. Neither is a usable qualified type.
	if idx <= 0 || idx == len(qualified)-1 {
		return "", "", fmt.Errorf(`expected qualified type as "pkg/path.MyType"`)
	}
	return sourceType[:idx], sourceType[idx+1:], nil
}