/
root.go
148 lines (131 loc) · 3.38 KB
/
root.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Copyright (c) 2023 Blockwatch Data Inc.
// Authors
// - jean.schmitt@ubisoft.com
// - abdul@blockwatch.cc
package main
import (
	"bytes"
	"flag"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"time"

	"blockwatch.cc/tzgo/internal/generate"
	"blockwatch.cc/tzgo/internal/parse"
	"github.com/iancoleman/strcase"
	"github.com/pkg/errors"
	"gopkg.in/yaml.v3"
)
// errExit signals that parseFlags has already fully handled the invocation
// (for example the "version" subcommand) and nothing further should run.
// Presumably the caller treats it as a clean exit rather than a failure.
var errExit = errors.New("exit")

// Values bound to command-line flags in init and filled in by flag.Parse.
var (
	endpointFlag  string // -endpoint: RPC node base URL
	addressFlag   string // -address: contract address to fetch from the RPC node
	srcFlag       string // -src: local file holding the contract script
	nameFlag      string // -name: contract name
	pkgFlag       string // -pkg: package name of the generated code
	outFlag       string // -out: output file; stdout when empty
	fixupFileFlag string // -fixup: YAML file with identifier fixups
)
// init registers the command-line flags. Their values land in the
// package-level *Flag variables once flag.Parse has run.
func init() {
	flag.StringVar(&endpointFlag, "endpoint", "https://rpc.tzstats.com", "rpc endpoint")
	flag.StringVar(&addressFlag, "address", "", "address of the contract. required if -src is not set")
	flag.StringVar(&srcFlag, "src", "", "json file containing the contract's script")
	// runCommand rejects an empty -name or -pkg, so say so in the help text.
	flag.StringVar(&nameFlag, "name", "", "name of the contract (required)")
	flag.StringVar(&pkgFlag, "pkg", "", "package name of the output go code (required)")
	flag.StringVar(&outFlag, "out", "", "output file. Prints to Stdout if not set")
	flag.StringVar(&fixupFileFlag, "fixup", "", "yaml file to fix generated go code for automatically generated function/variable names")
}
// parseFlags handles the bare "version" and "help" subcommands and otherwise
// parses the command-line flags into the package-level *Flag variables.
//
// It returns errExit when a subcommand has already produced all of its output
// and the caller should stop without running the generator.
func parseFlags() error {
	if len(os.Args) >= 2 {
		switch os.Args[1] {
		case "version":
			printVersion()
			return errExit
		case "help":
			// Write the header to the same stream flag.PrintDefaults uses, so
			// the help text is not split between stdout and stderr.
			out := flag.CommandLine.Output()
			fmt.Fprintf(out, "Usage: %s [flags]\n", appName)
			fmt.Fprintln(out, "\nFlags")
			flag.PrintDefaults()
			// Stop here, as for "version". Falling through to flag.Parse
			// would let the caller go on to run the generator, which then
			// fails on the missing required flags.
			return errExit
		}
	}
	flag.Parse()
	return nil
}
// runCommand checks the required flags, obtains the contract script, renders
// the Go bindings and writes them to the configured destination. Every
// failure is wrapped with the stage at which it happened.
func runCommand() error {
	switch {
	case pkgFlag == "":
		return errors.New("-pkg is required, to get package name")
	case nameFlag == "":
		return errors.New("-name is required to set name of contract")
	}

	script, err := getSrc()
	if err != nil {
		return errors.Wrap(err, "failed to get contract script")
	}

	code, err := generateBindings(script)
	if err != nil {
		return errors.Wrap(err, "failed to generate bindings")
	}

	if err := writeResult(code); err != nil {
		return errors.Wrap(err, "failed to write generated code to file")
	}
	return nil
}
func generateBindings(script []byte) ([]byte, error) {
var err error
data := generate.Data{
Address: addressFlag,
Package: pkgFlag,
}
data.Contract, data.Structs, err = parse.Parse(script, nameFlag)
if err != nil {
return nil, err
}
if fixupFileFlag != "" {
fixupFile, err := os.ReadFile(fixupFileFlag)
if err != nil {
return nil, err
}
var fixupCfg parse.FixupConfig
err = yaml.NewDecoder(bytes.NewReader(fixupFile)).Decode(&fixupCfg)
if err != nil {
return nil, err
}
data.Structs = parse.Fixup(fixupCfg, data.Structs, strcase.ToCamel)
}
return generate.Render(&data)
}
// getSrc returns the contract script. It reads the -src file when set;
// otherwise it fetches the script of -address from the head block of the
// main chain on the -endpoint RPC node.
func getSrc() ([]byte, error) {
	if srcFlag != "" {
		return os.ReadFile(srcFlag)
	}

	// Without -src the script has to come from the RPC node, which needs an address.
	if addressFlag == "" {
		return nil, errors.New("-address is required when getting script from rpc")
	}
	u, err := url.JoinPath(endpointFlag, "chains/main/blocks/head/context/contracts", addressFlag, "script")
	if err != nil {
		return nil, err
	}

	// http.Get uses http.DefaultClient, which has no timeout, so a stalled
	// node would hang the generator forever. The timeout covers connecting,
	// the response headers and reading the body.
	const rpcTimeout = time.Minute
	client := &http.Client{Timeout: rpcTimeout}
	res, err := client.Get(u)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return nil, errors.Errorf("failed to get contract script at url %s: %v", u, res.Status)
	}
	return io.ReadAll(res.Body)
}
// writeResult emits the generated source. It writes to the -out file when that
// flag is set (created with mode 0644 before umask, or truncated if it
// exists); otherwise it writes to stdout.
func writeResult(out []byte) error {
	if outFlag != "" {
		return os.WriteFile(outFlag, out, 0o644)
	}
	_, err := os.Stdout.Write(out)
	return err
}