/
ast.go
92 lines (75 loc) · 1.8 KB
/
ast.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package ast
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"log"
"strings"
"github.com/fatih/astrewrite"
)
// DebugMode is the debugMode value that makes ChangeType log every AST node
// that it does not rewrite.
const DebugMode string = "DEV"
// RenamePackage builds a transformer (usable with ModifyAst) that replaces the
// package clause name of the file it receives with packageName. The file is
// modified in place and the same pointer is returned.
func RenamePackage(packageName string) func(file *ast.File) *ast.File {
	newName := packageName
	rename := func(target *ast.File) *ast.File {
		ident := &ast.Ident{Name: newName}
		target.Name = ident
		return target
	}
	return rename
}
func ChangeType(typeName string, newType string, debugMode string) func(file *ast.File) *ast.File {
return func(file *ast.File) *ast.File {
rewriteFunc := func(n ast.Node) (ast.Node, bool) {
switch x := n.(type) {
case *ast.Ident:
if typeName == x.Name {
x = &ast.Ident{Name: newType}
}
return x, true
case *ast.CallExpr:
for i := 0; i < len(x.Args); i++ {
v, ok := x.Args[i].(*ast.Ident)
if ok {
if strings.EqualFold(typeName, v.Name) {
x.Args[i] = &ast.Ident{Name: fmt.Sprintf("%s.(%s)", v.Name, newType)}
}
}
}
return x, true
default:
if debugMode == DebugMode {
log.Println("ast node:")
log.Println(fmt.Sprintf("verbose value: %#v", x))
log.Println(fmt.Sprintf("type: %T", x))
log.Println(fmt.Sprintf("value: %v", x))
}
}
return n, true
}
astrewrite.Walk(file, rewriteFunc)
return file
}
}
// ModifyAst parses dest as a Go source file and applies each transformer in
// fns in order; each one receives the file returned by the previous one. It
// returns the formatted source of the final file.
//
// The source is parsed without parser.ParseComments, so comments present in
// dest do not appear in the output.
//
// Nil entries in fns are skipped. If a transformer returns a nil file, an error
// is returned instead of passing nil on to go/format. Parse errors are
// returned unchanged; formatting errors are returned as *BadFormattedCode.
func ModifyAst(dest []byte, fns ...func(*ast.File) *ast.File) ([]byte, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "", dest, 0)
	if err != nil {
		return nil, err
	}
	for i, fn := range fns {
		if fn == nil {
			continue
		}
		if file = fn(file); file == nil {
			return nil, fmt.Errorf("ast transformer %d returned a nil file", i)
		}
	}
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return nil, &BadFormattedCode{Err: err}
	}
	return buf.Bytes(), nil
}

// BadFormattedCode reports that a transformed AST could not be printed back to
// source code. Err holds the error returned by go/format.
type BadFormattedCode struct {
	Err error
}

// Error implements the error interface.
func (e BadFormattedCode) Error() string {
	return fmt.Sprintf("couldn't format package code (%v)", e.Err)
}

// Unwrap returns the underlying formatting error so that errors.Is and
// errors.As can inspect it.
func (e BadFormattedCode) Unwrap() error {
	return e.Err
}