-
Notifications
You must be signed in to change notification settings - Fork 1
/
plugin.go
139 lines (117 loc) · 3.84 KB
/
plugin.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package gen
import (
"fmt"
"strings"
"google.golang.org/protobuf/compiler/protogen"
)
// generatePluginFile writes "<prefix>_plugin.pb.go", the file imported by
// plugins written in TinyGo. In order, it contains:
//   - a `//go:build tinygo.wasm` constraint;
//   - the shared header produced by generateHeader;
//   - the exported glue for every plugin service (genPlugin);
//   - the bindings for the host service, if one is declared (genHostFunctions).
//
// When the proto file declares neither plugin services nor a host service,
// the file is marked with Skip, which removes it from the plugin output.
// Generation is not short-circuited in that case; the steps below still run.
func (gg *Generator) generatePluginFile(f *fileInfo) {
	out := gg.plugin.NewGeneratedFile(f.GeneratedFilenamePrefix+"_plugin.pb.go", f.GoImportPath)

	hasPluginServices := len(f.pluginServices) > 0
	hasHostService := f.hostService != nil
	if !hasPluginServices && !hasHostService {
		out.Skip()
	}

	// Build constraints have to precede the package clause (presumably
	// emitted by generateHeader), so the constraint is written first.
	out.P("//go:build tinygo.wasm")
	gg.generateHeader(out, f)

	// One set of exported wrappers per plugin service.
	for _, svc := range f.pluginServices {
		genPlugin(out, f, svc)
	}
	genHostFunctions(out, f)
}
// genPlugin emits the guest-side glue for a single plugin service into g:
//
//   - a <Service>PluginAPIVersion constant and an exported
//     <service>_api_version function that returns it;
//   - a package-level variable holding the implementation, plus a
//     Register<Service> function that sets it;
//   - for each RPC method, an exported function _<service_method>(ptr, size).
//     It reads the request bytes with PtrToByte and decodes them with
//     UnmarshalVT. It then calls the registered implementation with
//     context.Background() and returns the MarshalVT-encoded response,
//     packed as ptr<<32 | size.
//
// Every failure is reported to the host the same way. This covers decoding
// the request, the implementation returning an error, and encoding the
// response. The error message is copied out with ByteToPtr and returned
// packed as ptr<<32 | size with ErrorMaskBit OR'd in. Decode and encode
// failures used to return a bare 0, which carries no error bit, so the host
// could not tell them from success.
//
// f is not used here.
func genPlugin(g *protogen.GeneratedFile, f *fileInfo, service *serviceInfo) {
	serviceVar := strings.ToLower(service.GoName[:1]) + service.GoName[1:]
	serviceSnake := toSnakeCase(service.GoName)

	// API version
	g.P("const ", service.GoName, "PluginAPIVersion = ", service.Version)
	g.P(fmt.Sprintf(`
//export %s_api_version
func _%s_api_version() uint64 {
	return %sPluginAPIVersion
}`, serviceSnake, serviceSnake, service.GoName))

	// Holder for the implementation registered by the plugin author.
	g.P("var ", serviceVar, " ", service.GoName)

	// Register function
	g.P("func Register", service.GoName, "(p ", service.GoName, ") {")
	g.P(serviceVar, " = p")
	g.P("}")

	// Exported functions, one per RPC method. The qualified identifiers are
	// resolved inside the loop (in the original order) so that a service
	// without methods does not pull unused imports into the generated file.
	for _, method := range service.Methods {
		exportedName := toSnakeCase(service.GoName + method.GoName)
		ptrToByte := g.QualifiedGoIdent(pluginWasmPackage.Ident("PtrToByte"))
		input := g.QualifiedGoIdent(method.Input.GoIdent)
		background := g.QualifiedGoIdent(contextPackage.Ident("Background"))
		byteToPtr := g.QualifiedGoIdent(pluginWasmPackage.Ident("ByteToPtr"))

		// Generated statements that hand `err` back to the host. The message
		// is flagged by OR-ing ErrorMaskBit into the packed result. As the
		// original comment put it, that sets the 32nd bit of the size half,
		// assuming no payload exceeds 31-bit size (2 GiB).
		errReturn := fmt.Sprintf("ptr, size = %s([]byte(err.Error()))\n"+
			"return (uint64(ptr) << uint64(32)) | uint64(size) | %s",
			byteToPtr, ErrorMaskBit)

		g.P("//export ", exportedName)
		g.P("func _", exportedName, "(ptr, size uint32) uint64 {")
		g.P("b := ", ptrToByte, "(ptr, size)")
		g.P("req := new(", input, ")")
		g.P("if err := req.UnmarshalVT(b); err != nil {")
		g.P(errReturn)
		g.P("}")
		g.P("response, err := ", serviceVar, ".", method.GoName, "(", background, "(), req)")
		g.P("if err != nil {")
		g.P(errReturn)
		g.P("}")
		g.P("b, err = response.MarshalVT()")
		g.P("if err != nil {")
		g.P(errReturn)
		g.P("}")
		g.P("ptr, size = ", byteToPtr, "(b)")
		g.P("return (uint64(ptr) << uint64(32)) | uint64(size)")
		g.P("}")
	}
}
// genHostFunctions emits the guest-side bindings for the file's host service.
// It does nothing when the file declares no host service.
//
// It generates:
//   - an empty unexported struct named after the service (first letter
//     lower-cased) and a New<Service> constructor returning it;
//   - for every method, a `//go:wasmimport <module> <snake_method>`
//     declaration of the raw import, _<snake_method>(ptr, size uint32) uint64;
//   - a method on the struct wrapping each import.
//
// Each wrapper method:
//   - encodes the request with MarshalVT and copies it out with ByteToPtr;
//   - calls the import and then frees the request buffer with FreePtr;
//   - splits the returned uint64 into ptr (high 32 bits) and size
//     (low 32 bits);
//   - decodes the bytes at that location with UnmarshalVT into a new
//     response message.
//
// The generated file is also forced to import the unsafe package.
//
// NOTE(review): the response buffer at the returned ptr is not freed after
// decoding; confirm whether the host or the guest owns that memory.
func genHostFunctions(g *protogen.GeneratedFile, f *fileInfo) {
	host := f.hostService
	if host == nil {
		return
	}
	// Presumably required by the toolchain for //go:wasmimport — confirm.
	g.Import(unsafePackage)

	// Host functions: an empty struct type plus its constructor.
	implName := strings.ToLower(host.GoName[:1]) + host.GoName[1:]
	g.P("type ", implName, " struct{}")
	g.P()
	g.P("func New", host.GoName, "()", host.GoName, "{")
	g.P(" return ", implName, "{}")
	g.P("}")

	for _, m := range host.Methods {
		wasmName := toSnakeCase(m.GoName)

		// Resolved in the same order as they appear in the template below, so
		// package-name assignment in the generated file is unaffected.
		ctxType := g.QualifiedGoIdent(contextPackage.Ident("Context"))
		reqType := g.QualifiedGoIdent(m.Input.GoIdent)
		respType := g.QualifiedGoIdent(m.Output.GoIdent)
		byteToPtr := g.QualifiedGoIdent(pluginWasmPackage.Ident("ByteToPtr"))
		freePtr := g.QualifiedGoIdent(pluginWasmPackage.Ident("FreePtr"))
		ptrToByte := g.QualifiedGoIdent(pluginWasmPackage.Ident("PtrToByte"))

		code := fmt.Sprintf(`
//go:wasmimport %s %s
func _%s(ptr uint32, size uint32) uint64
func (h %s) %s(ctx %s, request *%s) (*%s, error) {
buf, err := request.MarshalVT()
if err != nil {
return nil, err
}
ptr, size := %s(buf)
ptrSize := _%s(ptr, size)
%s(ptr)
ptr = uint32(ptrSize >> 32)
size = uint32(ptrSize)
buf = %s(ptr, size)
response := new(%s)
if err = response.UnmarshalVT(buf); err != nil {
return nil, err
}
return response, nil
}`,
			host.Module, wasmName, wasmName, implName, m.GoName,
			ctxType, reqType, respType,
			byteToPtr, wasmName, freePtr, ptrToByte, respType,
		)
		g.P(code)
	}
}