diff --git a/extism.go b/extism.go index 03c57e1..9364786 100644 --- a/extism.go +++ b/extism.go @@ -28,7 +28,6 @@ type Runtime struct { Wazero wazero.Runtime Extism api.Module Env api.Module - ctx context.Context hasWasi bool } @@ -302,7 +301,12 @@ func (m *Manifest) UnmarshalJSON(data []byte) error { // Close closes the plugin by freeing the underlying resources. func (p *Plugin) Close() error { - return p.Runtime.Wazero.Close(p.Runtime.ctx) + return p.CloseWithContext(context.Background()) +} + +// CloseWithContext closes the plugin by freeing the underlying resources. +func (p *Plugin) CloseWithContext(ctx context.Context) error { + return p.Runtime.Wazero.Close(ctx) } // NewPlugin creates a new Extism plugin with the given manifest, configuration, and host functions. @@ -351,17 +355,16 @@ func NewPlugin( Wazero: rt, Extism: extism, Env: env, - ctx: ctx, } if config.EnableWasi { - wasi_snapshot_preview1.MustInstantiate(c.ctx, c.Wazero) + wasi_snapshot_preview1.MustInstantiate(ctx, c.Wazero) c.hasWasi = true } for name, funcs := range hostModules { - _, err := buildHostModule(c.ctx, c.Wazero, name, funcs) + _, err := buildHostModule(ctx, c.Wazero, name, funcs) if err != nil { return nil, err } @@ -429,7 +432,7 @@ func NewPlugin( } } - m, err := c.Wazero.InstantiateWithConfig(c.ctx, data.Data, moduleConfig.WithName(data.Name)) + m, err := c.Wazero.InstantiateWithConfig(ctx, data.Data, moduleConfig.WithName(data.Name)) if err != nil { return nil, err } @@ -470,7 +473,7 @@ func NewPlugin( logLevel: logLevel, } - p.guestRuntime = detectGuestRuntime(p) + p.guestRuntime = detectGuestRuntime(ctx, p) return p, nil } @@ -482,29 +485,39 @@ func NewPlugin( // SetInput sets the input data for the plugin to be used in the next WebAssembly function call. func (plugin *Plugin) SetInput(data []byte) (uint64, error) { - _, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(plugin.Runtime.ctx) + return plugin.SetInputWithContext(context.Background(), data) +} + +// SetInputWithContext sets the input data for the plugin to be used in the next WebAssembly function call. +func (plugin *Plugin) SetInputWithContext(ctx context.Context, data []byte) (uint64, error) { + _, err := plugin.Runtime.Extism.ExportedFunction("reset").Call(ctx) if err != nil { fmt.Println(err) return 0, errors.New("reset") } - ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(plugin.Runtime.ctx, uint64(len(data))) + ptr, err := plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(len(data))) if err != nil { return 0, err } plugin.Memory().Write(uint32(ptr[0]), data) - plugin.Runtime.Extism.ExportedFunction("input_set").Call(plugin.Runtime.ctx, ptr[0], uint64(len(data))) + plugin.Runtime.Extism.ExportedFunction("input_set").Call(ctx, ptr[0], uint64(len(data))) return ptr[0], nil } // GetOutput retrieves the output data from the last WebAssembly function call. func (plugin *Plugin) GetOutput() ([]byte, error) { - outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(plugin.Runtime.ctx) + return plugin.GetOutputWithContext(context.Background()) +} + +// GetOutputWithContext retrieves the output data from the last WebAssembly function call. +func (plugin *Plugin) GetOutputWithContext(ctx context.Context) ([]byte, error) { + outputOffs, err := plugin.Runtime.Extism.ExportedFunction("output_offset").Call(ctx) if err != nil { return []byte{}, err } - outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(plugin.Runtime.ctx) + outputLen, err := plugin.Runtime.Extism.ExportedFunction("output_length").Call(ctx) if err != nil { return []byte{}, err } @@ -524,7 +537,12 @@ func (plugin *Plugin) Memory() api.Memory { // GetError retrieves the error message from the last WebAssembly function call, if any. func (plugin *Plugin) GetError() string { - errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(plugin.Runtime.ctx) + return plugin.GetErrorWithContext(context.Background()) +} + +// GetErrorWithContext retrieves the error message from the last WebAssembly function call. +func (plugin *Plugin) GetErrorWithContext(ctx context.Context) string { + errOffs, err := plugin.Runtime.Extism.ExportedFunction("error_get").Call(ctx) if err != nil { return "" } @@ -533,7 +551,7 @@ func (plugin *Plugin) GetError() string { return "" } - errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(plugin.Runtime.ctx, errOffs[0]) + errLen, err := plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, errOffs[0]) if err != nil { return "" } @@ -549,7 +567,7 @@ func (plugin *Plugin) FunctionExists(name string) bool { // Call a function by name with the given input, returning the output func (plugin *Plugin) Call(name string, data []byte) (uint32, []byte, error) { - return plugin.CallWithContext(plugin.Runtime.ctx, name, data) + return plugin.CallWithContext(context.Background(), name, data) } // Call a function by name with the given input and context, returning the output @@ -579,7 +597,7 @@ func (plugin *Plugin) CallWithContext(ctx context.Context, name string, data []b var isStart = name == "_start" if plugin.guestRuntime.init != nil && !isStart && !plugin.guestRuntime.initialized { - err := plugin.guestRuntime.init() + err := plugin.guestRuntime.init(ctx) if err != nil { return 1, []byte{}, errors.New(fmt.Sprintf("failed to initialize runtime: %v", err)) } diff --git a/extism_test.go b/extism_test.go index d2349fe..7fb128d 100644 --- a/extism_test.go +++ b/extism_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/experimental" + "github.com/tetratelabs/wazero/experimental/logging" "github.com/tetratelabs/wazero/sys" ) @@ -518,7 +520,7 @@ func TestCancel(t *testing.T) { manifest := manifest("sleep.wasm") manifest.Config["duration"] = "3" // sleep for 3 seconds - ctx, cancel := context.WithCancel(context.Background()) + ctx := context.Background() config := PluginConfig{ ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(), EnableWasi: true, @@ -533,12 +535,13 @@ func TestCancel(t *testing.T) { defer plugin.Close() + ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(100 * time.Millisecond) cancel() }() - exit, _, err := plugin.Call("run_test", []byte{}) + exit, _, err := plugin.CallWithContext(ctx, "run_test", []byte{}) assert.Equal(t, sys.ExitCodeContextCanceled, exit, "Exit code must be `sys.ExitCodeContextCanceled`") assert.Equal(t, "module closed with context canceled", err.Error()) @@ -734,6 +737,70 @@ func TestInputOffset(t *testing.T) { } } +// make sure cancelling the context given to NewPlugin doesn't affect plugin calls +func TestContextCancel(t *testing.T) { + manifest := manifest("sleep.wasm") + manifest.Config["duration"] = "0" // sleep for 0 seconds + + ctx, cancel := context.WithCancel(context.Background()) + config := PluginConfig{ + ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(), + EnableWasi: true, + RuntimeConfig: wazero.NewRuntimeConfig().WithCloseOnContextDone(true), + } + + plugin, err := NewPlugin(ctx, manifest, config, []HostFunction{}) + + if err != nil { + t.Errorf("Could not create plugin: %v", err) + } + + defer plugin.Close() + cancel() // cancel the parent context + + exit, out, err := plugin.CallWithContext(context.Background(), "run_test", []byte{}) + + if assertCall(t, err, exit) { + assert.Equal(t, "slept for 0 seconds", string(out)) + } +} + +// make sure we can still turn on experimental wazero features +func TestEnableExperimentalFeature(t *testing.T) { + var buf bytes.Buffer + + // Set context to one that has an experimental listener + ctx := context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&buf)) + + manifest := manifest("sleep.wasm") + manifest.Config["duration"] = "0" // sleep for 0 seconds + + config := PluginConfig{ + ModuleConfig: wazero.NewModuleConfig().WithSysWalltime(), + EnableWasi: true, + RuntimeConfig: wazero.NewRuntimeConfig().WithCloseOnContextDone(true), + } + + plugin, err := NewPlugin(ctx, manifest, config, []HostFunction{}) + + if err != nil { + t.Errorf("Could not create plugin: %v", err) + } + + defer plugin.Close() + + var buf2 bytes.Buffer + ctx = context.WithValue(context.Background(), experimental.FunctionListenerFactoryKey{}, logging.NewLoggingListenerFactory(&buf2)) + exit, out, err := plugin.CallWithContext(ctx, "run_test", []byte{}) + + if assertCall(t, err, exit) { + assert.Equal(t, "slept for 0 seconds", string(out)) + + assert.NotEmpty(t, buf.String()) + assert.Empty(t, buf2.String()) + } +} + func BenchmarkInitialize(b *testing.B) { ctx := context.Background() cache := wazero.NewCompilationCache() diff --git a/host.go b/host.go index e28e444..05fec11 100644 --- a/host.go +++ b/host.go @@ -123,7 +123,12 @@ func (p *CurrentPlugin) Memory() api.Memory { // Alloc a new memory block of the given length, returning its offset func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) { - out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(p.plugin.Runtime.ctx, uint64(n)) + return p.AllocWithContext(context.Background(), n) +} + +// Alloc a new memory block of the given length, returning its offset +func (p *CurrentPlugin) AllocWithContext(ctx context.Context, n uint64) (uint64, error) { + out, err := p.plugin.Runtime.Extism.ExportedFunction("alloc").Call(ctx, uint64(n)) if err != nil { return 0, err } else if len(out) != 1 { @@ -135,7 +140,12 @@ func (p *CurrentPlugin) Alloc(n uint64) (uint64, error) { // Free the memory block specified by the given offset func (p *CurrentPlugin) Free(offset uint64) error { - _, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(p.plugin.Runtime.ctx, uint64(offset)) + return p.FreeWithContext(context.Background(), offset) +} + +// Free the memory block specified by the given offset +func (p *CurrentPlugin) FreeWithContext(ctx context.Context, offset uint64) error { + _, err := p.plugin.Runtime.Extism.ExportedFunction("free").Call(ctx, uint64(offset)) if err != nil { return err } @@ -145,7 +155,12 @@ func (p *CurrentPlugin) Free(offset uint64) error { // Length returns the number of bytes allocated at the specified offset func (p *CurrentPlugin) Length(offs uint64) (uint64, error) { - out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(p.plugin.Runtime.ctx, uint64(offs)) + return p.LengthWithContext(context.Background(), offs) +} + +// Length returns the number of bytes allocated at the specified offset +func (p *CurrentPlugin) LengthWithContext(ctx context.Context, offs uint64) (uint64, error) { + out, err := p.plugin.Runtime.Extism.ExportedFunction("length").Call(ctx, uint64(offs)) if err != nil { return 0, err } else if len(out) != 1 { diff --git a/runtime.go b/runtime.go index c4b14c4..8013b93 100644 --- a/runtime.go +++ b/runtime.go @@ -1,6 +1,8 @@ package extism import ( + "context" + "github.com/tetratelabs/wazero/api" ) @@ -16,31 +18,31 @@ const ( type guestRuntime struct { runtimeType runtimeType - init func() error + init func(ctx context.Context) error initialized bool } -func detectGuestRuntime(p *Plugin) guestRuntime { +func detectGuestRuntime(ctx context.Context, p *Plugin) guestRuntime { m := p.Main - runtime, ok := haskellRuntime(p, m) + runtime, ok := haskellRuntime(ctx, p, m) if ok { return runtime } - runtime, ok = wasiRuntime(p, m) + runtime, ok = wasiRuntime(ctx, p, m) if ok { return runtime } p.Log(LogLevelTrace, "No runtime detected") - return guestRuntime{runtimeType: None, init: func() error { return nil }, initialized: true} + return guestRuntime{runtimeType: None, init: func(_ context.Context) error { return nil }, initialized: true} } // Check for Haskell runtime initialization functions // Initialize Haskell runtime if `hs_init` and `hs_exit` are present, // by calling the `hs_init` export -func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { +func haskellRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) { initFunc := m.ExportedFunction("hs_init") if initFunc == nil { return guestRuntime{}, false @@ -54,14 +56,14 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { reactorInit := m.ExportedFunction("_initialize") - init := func() error { + init := func(ctx context.Context) error { if reactorInit != nil { - _, err := reactorInit.Call(p.Runtime.ctx) + _, err := reactorInit.Call(ctx) if err != nil { p.Logf(LogLevelError, "Error running reactor _initialize: %s", err.Error()) } } - _, err := initFunc.Call(p.Runtime.ctx, 0, 0) + _, err := initFunc.Call(ctx, 0, 0) if err == nil { p.Log(LogLevelDebug, "Initialized Haskell language runtime.") } @@ -74,7 +76,7 @@ func haskellRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { } // Check for initialization functions defined by the WASI standard -func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { +func wasiRuntime(ctx context.Context, p *Plugin, m api.Module) (guestRuntime, bool) { if !p.Runtime.hasWasi { return guestRuntime{}, false } @@ -82,16 +84,16 @@ func wasiRuntime(p *Plugin, m api.Module) (guestRuntime, bool) { // WASI supports two modules: Reactors and Commands // we prioritize Reactors over Commands // see: https://github.com/WebAssembly/WASI/blob/main/legacy/application-abi.md - if r, ok := reactorModule(m, p); ok { + if r, ok := reactorModule(ctx, m, p); ok { return r, ok } - return commandModule(m, p) + return commandModule(ctx, m, p) } // Check for `_initialize` this is used by WASI to initialize certain interfaces. -func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) { - init := findFunc(m, p, "_initialize") +func reactorModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) { + init := findFunc(ctx, m, p, "_initialize") if init == nil { return guestRuntime{}, false } @@ -104,8 +106,8 @@ func reactorModule(m api.Module, p *Plugin) (guestRuntime, bool) { // Check for `__wasm__call_ctors`, this is used by WASI to // initialize certain interfaces. -func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) { - init := findFunc(m, p, "__wasm_call_ctors") +func commandModule(ctx context.Context, m api.Module, p *Plugin) (guestRuntime, bool) { + init := findFunc(ctx, m, p, "__wasm_call_ctors") if init == nil { return guestRuntime{}, false } @@ -116,7 +118,7 @@ func commandModule(m api.Module, p *Plugin) (guestRuntime, bool) { return guestRuntime{runtimeType: Wasi, init: init}, true } -func findFunc(m api.Module, p *Plugin, name string) func() error { +func findFunc(ctx context.Context, m api.Module, p *Plugin, name string) func(context.Context) error { initFunc := m.ExportedFunction(name) if initFunc == nil { return nil @@ -128,9 +130,9 @@ func findFunc(m api.Module, p *Plugin, name string) func() error { return nil } - return func() error { + return func(ctx context.Context) error { p.Logf(LogLevelDebug, "Calling %v", name) - _, err := initFunc.Call(p.Runtime.ctx) + _, err := initFunc.Call(ctx) return err } }