Skip to content

Commit

Permalink
Correctly use context in plugin and provide alternative _WithContext …
Browse files Browse the repository at this point in the history
…methods (#62)
  • Loading branch information
Marton6 committed Mar 12, 2024
1 parent a1a2815 commit 9101916
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 40 deletions.
50 changes: 34 additions & 16 deletions extism.go
Expand Up @@ -28,7 +28,6 @@ type Runtime struct {
Wazero wazero.Runtime
Extism api.Module
Env api.Module
ctx context.Context
hasWasi bool
}

Expand Down Expand Up @@ -302,7 +301,12 @@ func (m *Manifest) UnmarshalJSON(data []byte) error {

// Close closes the plugin by freeing the underlying resources.
func (p *Plugin) Close() error {
return p.Runtime.Wazero.Close(p.Runtime.ctx)
return p.CloseWithContext(context.Background())
}

// CloseWithContext closes the plugin by freeing the underlying resources.
func (p *Plugin) CloseWithContext(ctx context.Context) error {
return p.Runtime.Wazero.Close(ctx)
}

// NewPlugin creates a new Extism plugin with the given manifest, configuration, and host functions.
Expand Down Expand Up @@ -351,17 +355,16 @@ func NewPlugin(
Wazero: rt,
Extism: extism,
Env: env,
ctx: ctx,
}

if config.EnableWasi {
wasi_snapshot_preview1.MustInstantiate(c.ctx, c.Wazero)
wasi_snapshot_preview1.MustInstantiate(ctx, c.Wazero)

c.hasWasi = true
}

for name, funcs := range hostModules {
_, err := buildHostModule(c.ctx, c.Wazero, name, funcs)
_, err := buildHostModule(ctx, c.Wazero, name, funcs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -429,7 +432,7 @@ func NewPlugin(
}
}

m, err := c.Wazero.InstantiateWithConfig(c.ctx, data.Data, moduleConfig.WithName(data.Name))
m, err := c.Wazero.InstantiateWithConfig(ctx, data.Data, moduleConfig.WithName(data.Name))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -470,7 +473,7 @@ func NewPlugin(
logLevel: logLevel,
}

p.guestRuntime = detectGuestRuntime(p)
p.guestRuntime = detectGuestRuntime(ctx, p)
return p, nil
}

Expand All @@ -482,29 +485,39 @@ func NewPlugin(

// SetInput sets the input data for the plugin to be used in the next WebAssembly function call.
func (plugin *Plugin) SetInput(data []byte) (uint64, error) {
_, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(plugin.Runtime.ctx)
return plugin.SetInputWithContext(context.Background(), data)
}

// SetInputWithContext sets the input data for the plugin to be used in the next WebAssembly function call.
func (plugin *Plugin) SetInputWithContext(ctx context.Context, data []byte) (uint64, error) {
_, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(ctx)
if err != nil {
fmt.Println(err)
return 0, errors.New("reset")
}

ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(plugin.Runtime.ctx, uint64(len(data)))
ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(len(data)))
if err != nil {
return 0, err
}
plugin.Memory().Write(uint32(ptr[0]), data)
plugin.Runtime.Extism.ExportedFunction("input_set").Call(plugin.Runtime.ctx, ptr[0], uint64(len(data)))
plugin.Runtime.Extism.ExportedFunction("input_set").Call(ctx, ptr[0], uint64(len(data)))
return ptr[0], nil
}

// GetOutput retrieves the output data from the last WebAssembly function call.
func (plugin *Plugin) GetOutput() ([]byte, error) {
outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(plugin.Runtime.ctx)
return plugin.GetOutputWithContext(context.Background())
}

// GetOutputWithContext retrieves the output data from the last WebAssembly function call.
func (plugin *Plugin) GetOutputWithContext(ctx context.Context) ([]byte, error) {
outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(ctx)
if err != nil {
return []byte{}, err
}

outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(plugin.Runtime.ctx)
outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(ctx)
if err != nil {
return []byte{}, err
}
Expand All @@ -524,7 +537,12 @@ func (plugin *Plugin) Memory() api.Memory {

// GetError retrieves the error message from the last WebAssembly function call, if any.
func (plugin *Plugin) GetError() string {
errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(plugin.Runtime.ctx)
return plugin.GetErrorWithContext(context.Background())
}

// GetErrorWithContext retrieves the error message from the last WebAssembly function call.
func (plugin *Plugin) GetErrorWithContext(ctx context.Context) string {
errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(ctx)
if err != nil {
return ""
}
Expand All @@ -533,7 +551,7 @@ func (plugin *Plugin) GetError() string {
return ""
}

errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(plugin.Runtime.ctx, errOffs[0])
errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, errOffs[0])
if err != nil {
return ""
}
Expand All @@ -549,7 +567,7 @@ func (plugin *Plugin) FunctionExists(name string) bool {

// Call a function by name with the given input, returning the output
func (plugin *Plugin) Call(name string, data []byte) (uint32, []byte, error) {
return plugin.CallWithContext(plugin.Runtime.ctx, name, data)
return plugin.CallWithContext(context.Background(), name, data)
}

// Call a function by name with the given input and context, returning the output
Expand Down Expand Up @@ -579,7 +597,7 @@ func (plugin *Plugin) CallWithContext(ctx context.Context, name string, data []b

var isStart = name == "_start"
if plugin.guestRuntime.init != nil && !isStart && !plugin.guestRuntime.initialized {
err := plugin.guestRuntime.init()
err := plugin.guestRuntime.init(ctx)
if err != nil {
return 1, []byte{}, errors.New(fmt.Sprintf("failed to initialize runtime: %v", err))
}
Expand Down
71 changes: 69 additions & 2 deletions extism_test.go
Expand Up @@ -13,6 +13,8 @@ import (

"github.com/stretchr/testify/assert"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/experimental"
"github.com/tetratelabs/wazero/experimental/logging"
"github.com/tetratelabs/wazero/sys"
)

Expand Down Expand Up @@ -518,7 +520,7 @@ func TestCancel(t *testing.T) {
manifest := manifest("sleep.wasm")
manifest.Config["duration"] = "3" // sleep for 3 seconds

ctx, cancel := context.WithCancel(context.Background())
ctx := context.Background()
config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
EnableWasi: true,
Expand All @@ -533,12 +535,13 @@ func TestCancel(t *testing.T) {

defer plugin.Close()

ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()

exit, _, err := plugin.Call("run_test", []byte{})
exit, _, err := plugin.CallWithContext(ctx, "run_test", []byte{})

assert.Equal(t, sys.ExitCodeContextCanceled, exit, "Exit code must be `sys.ExitCodeContextCanceled`")
assert.Equal(t, "module closed with context canceled", err.Error())
Expand Down Expand Up @@ -734,6 +737,70 @@ func TestInputOffset(t *testing.T) {
}
}

// make sure cancelling the context given to NewPlugin doesn't affect plugin calls
func TestContextCancel(t *testing.T) {
manifest := manifest("sleep.wasm")
manifest.Config["duration"] = "0" // sleep for 0 seconds

ctx, cancel := context.WithCancel(context.Background())
config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
EnableWasi: true,
RuntimeConfig: wazero.NewRuntimeConfig().WithCloseOnContextDone(true),
}

plugin, err := NewPlugin(ctx, manifest, config, []HostFunction{})

if err != nil {
t.Errorf("Could not create plugin: %v", err)
}

defer plugin.Close()
cancel() // cancel the parent context

exit, out, err := plugin.CallWithContext(context.Background(), "run_test", []byte{})

if assertCall(t, err, exit) {
assert.Equal(t, "slept for 0 seconds", string(out))
}
}

// make sure we can still turn on experimental wazero features
func TestEnableExperimentalFeature(t *testing.T) {
var buf bytes.Buffer

// Set context to one that has an experimental listener
ctx := context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&buf))

manifest := manifest("sleep.wasm")
manifest.Config["duration"] = "0" // sleep for 0 seconds

config := PluginConfig{
ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(),
EnableWasi: true,
RuntimeConfig: wazero.NewRuntimeConfig().WithCloseOnContextDone(true),
}

plugin, err := NewPlugin(ctx, manifest, config, []HostFunction{})

if err != nil {
t.Errorf("Could not create plugin: %v", err)
}

defer plugin.Close()

var buf2 bytes.Buffer
ctx = context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&buf2))
exit, out, err := plugin.CallWithContext(ctx, "run_test", []byte{})

if assertCall(t, err, exit) {
assert.Equal(t, "slept for 0 seconds", string(out))

assert.NotEmpty(t, buf.String())
assert.Empty(t, buf2.String())
}
}

func BenchmarkInitialize(b *testing.B) {
ctx := context.Background()
cache := wazero.NewCompilationCache()
Expand Down
21 changes: 18 additions & 3 deletions host.go
Expand Up @@ -123,7 +123,12 @@ func (p *CurrentPlugin) Memory() api.Memory {

// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(p.plugin.Runtime.ctx, uint64(n))
return p.AllocWithContext(context.Background(), n)
}

// Alloc a new memory block of the given length, returning its offset
func (p *CurrentPlugin) AllocWithContext(ctx context.Context, n uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(n))
if err != nil {
return 0, err
} else if len(out) != 1 {
Expand All @@ -135,7 +140,12 @@ func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) {

// Free the memory block specified by the given offset
func (p *CurrentPlugin) Free(offset uint64) error {
_, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(p.plugin.Runtime.ctx, uint64(offset))
return p.FreeWithContext(context.Background(), offset)
}

// Free the memory block specified by the given offset
func (p *CurrentPlugin) FreeWithContext(ctx context.Context, offset uint64) error {
_, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(ctx, uint64(offset))
if err != nil {
return err
}
Expand All @@ -145,7 +155,12 @@ func (p *CurrentPlugin) Free(offset uint64) error {

// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) Length(offs uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(p.plugin.Runtime.ctx, uint64(offs))
return p.LengthWithContext(context.Background(), offs)
}

// Length returns the number of bytes allocated at the specified offset
func (p *CurrentPlugin) LengthWithContext(ctx context.Context, offs uint64) (uint64, error) {
out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, uint64(offs))
if err != nil {
return 0, err
} else if len(out) != 1 {
Expand Down

0 comments on commit 9101916

Please sign in to comment.