/
registry.go
108 lines (85 loc) · 2.13 KB
/
registry.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package types
import (
"fmt"
"go/types"
"strings"
"github.com/donatorsky/go-cmder/internal/utils"
"golang.org/x/tools/go/packages"
)
// NewRegistry builds a Registry for pkg.
//
// Every direct import of pkg (pkg.Imports) gets one entry, keyed by the
// imported package's path and carrying its declared package name. If a file in
// pkg.Syntax imports that package under an explicit name other than "." or "_",
// that name is stored as the entry's Alias. When the same package is given an
// explicit name more than once, the last matching import spec encountered (in
// pkg.Syntax order, then source order within a file) wins.
func NewRegistry(pkg *packages.Package) *Registry {
	r := &Registry{
		selfPkg: pkg.PkgPath,
		types:   make(map[string]*Type, len(pkg.Imports)),
		imports: utils.NewUniqueSlice[*Type](),
	}

	for _, p := range pkg.Imports {
		r.types[p.PkgPath] = &Type{
			Alias: nil,
			Name:  p.Name,
			Path:  p.PkgPath,
		}
	}

	for _, file := range pkg.Syntax {
		for _, spec := range file.Imports {
			// Only explicit, usable names are aliases: dot and blank imports
			// introduce no qualifier.
			if spec.Name == nil || spec.Name.Name == "." || spec.Name.Name == "_" {
				continue
			}

			// An import path is a string literal and may be written either as
			// an interpreted ("...") or a raw (`...`) literal; strip whichever
			// delimiter was used so the key matches the package path.
			path := strings.Trim(spec.Path.Value, "\"`")

			t, ok := r.types[path]
			if !ok {
				continue
			}
			t.Alias = &spec.Name.Name
		}
	}

	return r
}
// Registry maps the direct imports of a single package to the qualifiers used
// when rendering types for that package, and tracks which of those imports
// Resolve has referenced.
type Registry struct {
	// selfPkg is the package path of the package the registry was built for.
	selfPkg string

	// types holds one entry per direct import of that package, keyed by
	// package path.
	types map[string]*Type

	// imports collects the entries of types that Resolve has referenced;
	// Imports returns its contents.
	imports *utils.UniqueSlice[*Type]
}
// Imports returns the imported packages that Resolve has referenced so far,
// as held by the registry's utils.UniqueSlice.
//
// NOTE(review): whether the returned slice aliases the registry's internal
// storage depends on utils.UniqueSlice.Items — confirm before mutating it.
func (r *Registry) Imports() []*Type {
	referenced := r.imports.Items()
	return referenced
}
// Resolve renders fieldType as Go source text for use inside the registry's
// own package (selfPkg).
//
// Leading pointer indirections are peeled off and returned separately in
// pointer as a run of "*" characters, one per level. The remaining type is
// returned in unwrappedType. For example, **pkg.T yields ("**", "q.T"), where
// q is the qualifier chosen for pkg.
//
// Qualifiers are chosen per package object rather than by textual
// substitution:
//   - types declared in selfPkg are written unqualified;
//   - types from a direct import are qualified with the import's recorded
//     alias when there is one, otherwise with its package name.
//
// This applies to every package reference anywhere in the type. That includes
// element types of slices, arrays, maps and channels, function parameters and
// results, struct fields, interface methods and type arguments. Each import
// referenced this way is appended to the registry's import list.
//
// An error is returned, and no imports are recorded, when the type refers to a
// package that is neither selfPkg nor one of its direct imports. Such a type
// cannot be spelled in the package without an import the registry does not
// know about.
func (r *Registry) Resolve(fieldType types.Type) (pointer string, unwrappedType string, _ error) {
	for {
		pointerType, ok := fieldType.(*types.Pointer)
		if !ok {
			break
		}
		pointer += "*"
		fieldType = pointerType.Elem()
	}

	var (
		referenced []*Type
		unknown    *types.Package
	)
	qualifier := func(p *types.Package) string {
		if p.Path() == r.selfPkg {
			return ""
		}
		t, ok := r.types[p.Path()]
		if !ok {
			if unknown == nil {
				unknown = p
			}
			// Keep rendering with the full path so the error can show the
			// complete type.
			return p.Path()
		}
		referenced = append(referenced, t)
		if t.Alias != nil {
			return *t.Alias
		}
		return t.Name
	}

	rendered := types.TypeString(fieldType, qualifier)
	if unknown != nil {
		return "", "", fmt.Errorf("type %s refers to package %q, which is not imported by %q", rendered, unknown.Path(), r.selfPkg)
	}

	// Register imports only after the whole type resolved successfully.
	for _, t := range referenced {
		// Append is called once per reference, repeats included.
		// NOTE(review): this relies on utils.UniqueSlice dropping duplicates —
		// confirm. Its return values are not needed here.
		_, _ = r.imports.Append(t)
	}
	return pointer, rendered, nil
}