Skip to content

Commit

Permalink
chore: add context to UsageHandler methods (#115)
Browse files Browse the repository at this point in the history
Because

- We want the pipeline trigger context to be propagated to the component
execution and usage handler.

This commit

- Adds a context param to `Execute`, `Check` and `Collect`.
  • Loading branch information
jvallesm committed May 10, 2024
1 parent ed2068f commit 80c4780
Show file tree
Hide file tree
Showing 34 changed files with 143 additions and 98 deletions.
9 changes: 5 additions & 4 deletions .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ the `hello` component will only need to override the following methods:
will return an object that implements the `Execute` method.
- `ExecutionWrapper` will wrap the execution call with the input and output
schema validation.
- `Execute([]*structpb.Struct) ([]*structpb.Struct, error)` is the most
- `Execute(context.Context []*structpb.Struct) ([]*structpb.Struct, error)` is the most
important function in the component. All the data manipulation will take place
here.

Expand Down Expand Up @@ -463,7 +463,7 @@ func (o *operator) CreateExecution(sysVars map[string]any, task string) (*base.E
return &base.ExecutionWrapper{Execution: e}, nil
}

func (e *execution) Execute(_ []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(context.Context, []*structpb.Struct) ([]*structpb.Struct, error) {
return nil, nil
}
```
Expand Down Expand Up @@ -498,7 +498,7 @@ func (o *operator) CreateExecution(sysVars map[string]any, task string) (*base.E
return &base.ExecutionWrapper{Execution: e}, nil
}

func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
outputs := make([]*structpb.Struct, len(inputs))

// An execution might take several inputs. One result will be returned for
Expand Down Expand Up @@ -549,6 +549,7 @@ import (

func TestOperator_Execute(t *testing.T) {
c := qt.New(t)
ctx := context.Background()

bo := base.BaseOperator{Logger: zap.NewNop()}
operator := Init(bo)
Expand All @@ -560,7 +561,7 @@ func TestOperator_Execute(t *testing.T) {
pbIn, err := structpb.NewStruct(map[string]any{"target": "bolero-wombat"})
c.Assert(err, qt.IsNil)

got, err := exec.Execution.Execute([]*structpb.Struct{pbIn})
got, err := exec.Execution.Execute(ctx, []*structpb.Struct{pbIn})
c.Check(err, qt.IsNil)
c.Assert(got, qt.HasLen, 1)

Expand Down
12 changes: 7 additions & 5 deletions pkg/base/execution.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package base

import (
"context"
"encoding/json"
"fmt"
"strconv"
Expand All @@ -25,11 +26,12 @@ type IExecution interface {
GetLogger() *zap.Logger
GetTaskInputSchema() string
GetTaskOutputSchema() string
GetSystemVariables() map[string]any

UsesSecret() bool
UsageHandlerCreator() UsageHandlerCreator

Execute([]*structpb.Struct) ([]*structpb.Struct, error)
Execute(context.Context, []*structpb.Struct) ([]*structpb.Struct, error)
}

func FormatErrors(inputPath string, e jsonschema.Detailed, errors *[]string) {
Expand Down Expand Up @@ -115,18 +117,18 @@ func Validate(data []*structpb.Struct, jsonSchema string, target string) error {
}

// Execute wraps the execution method with validation and usage collection.
func (e *ExecutionWrapper) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *ExecutionWrapper) Execute(ctx context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
if err := Validate(inputs, e.Execution.GetTaskInputSchema(), "inputs"); err != nil {
return nil, err
}

newUH := e.Execution.UsageHandlerCreator()
h := newUH(e.Execution)
if err := h.Check(inputs); err != nil {
if err := h.Check(ctx, inputs); err != nil {
return nil, err
}

outputs, err := e.Execution.Execute(inputs)
outputs, err := e.Execution.Execute(ctx, inputs)
if err != nil {
return nil, err
}
Expand All @@ -135,7 +137,7 @@ func (e *ExecutionWrapper) Execute(inputs []*structpb.Struct) ([]*structpb.Struc
return nil, err
}

if err := h.Collect(inputs, outputs); err != nil {
if err := h.Collect(ctx, inputs, outputs); err != nil {
return nil, err
}

Expand Down
14 changes: 9 additions & 5 deletions pkg/base/usage.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
package base

import "google.golang.org/protobuf/types/known/structpb"
import (
"context"

"google.golang.org/protobuf/types/known/structpb"
)

// UsageHandler allows the component execution wrapper to add checks and
// collect usage metrics around a component execution.
type UsageHandler interface {
Check(inputs []*structpb.Struct) error
Collect(inputs, outputs []*structpb.Struct) error
Check(ctx context.Context, inputs []*structpb.Struct) error
Collect(ctx context.Context, inputs, outputs []*structpb.Struct) error
}

// UsageHandlerCreator returns a function to initialize a UsageHandler.
type UsageHandlerCreator func(IExecution) UsageHandler

type noopUsageHandler struct{}

func (h *noopUsageHandler) Check([]*structpb.Struct) error { return nil }
func (h *noopUsageHandler) Collect(_, _ []*structpb.Struct) error {
func (h *noopUsageHandler) Check(context.Context, []*structpb.Struct) error { return nil }
func (h *noopUsageHandler) Collect(_ context.Context, _, _ []*structpb.Struct) error {
return nil
}

Expand Down
5 changes: 3 additions & 2 deletions pkg/connector/airbyte/v0/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package airbyte

import (
"context"
_ "embed"
"fmt"
"sync"
Expand Down Expand Up @@ -44,10 +45,10 @@ func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb
}}, nil
}

func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(context.Context, []*structpb.Struct) ([]*structpb.Struct, error) {
return nil, fmt.Errorf("the Airbyte connector has been removed")
}

func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
func (c *connector) Test(map[string]any, *structpb.Struct) error {
return fmt.Errorf("the Airbyte connector has been removed")
}
4 changes: 3 additions & 1 deletion pkg/connector/archetypeai/v0/connector_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package archetypeai

import (
"context"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -100,6 +101,7 @@ var (

func TestConnector_Execute(t *testing.T) {
c := qt.New(t)
ctx := context.Background()

testcases := []struct {
name string
Expand Down Expand Up @@ -265,7 +267,7 @@ func TestConnector_Execute(t *testing.T) {
pbIn, err := base.ConvertToStructpb(tc.in)
c.Assert(err, qt.IsNil)

got, err := exec.Execution.Execute([]*structpb.Struct{pbIn})
got, err := exec.Execution.Execute(ctx, []*structpb.Struct{pbIn})
if tc.wantErr != "" {
c.Check(errmsg.Message(err), qt.Matches, tc.wantErr)
return
Expand Down
3 changes: 2 additions & 1 deletion pkg/connector/archetypeai/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package archetypeai

import (
"bytes"
"context"
_ "embed"
"fmt"
"strings"
Expand Down Expand Up @@ -81,7 +82,7 @@ func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb
}

// Execute performs calls the Archetype AI API to execute a task.
func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
outputs := make([]*structpb.Struct, len(inputs))

for i, input := range inputs {
Expand Down
2 changes: 1 addition & 1 deletion pkg/connector/bigquery/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func getTableName(config *structpb.Struct) string {
return config.GetFields()["table_name"].GetStringValue()
}

func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(ctx context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
outputs := []*structpb.Struct{}

client, err := NewClient(getJSONKey(e.Connection), getProjectID(e.Connection))
Expand Down
2 changes: 1 addition & 1 deletion pkg/connector/googlecloudstorage/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func getJSONKey(config *structpb.Struct) string {
return config.GetFields()["json_key"].GetStringValue()
}

func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(ctx context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
outputs := []*structpb.Struct{}

client, err := NewClient(getJSONKey(e.Connection))
Expand Down
2 changes: 1 addition & 1 deletion pkg/connector/googlesearch/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func getSearchEngineID(config *structpb.Struct) string {
return config.GetFields()["cse_id"].GetStringValue()
}

func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(ctx context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {

service, err := NewService(getAPIKey(e.Connection))
if err != nil || service == nil {
Expand Down
6 changes: 4 additions & 2 deletions pkg/connector/huggingface/v0/connector_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package huggingface

import (
"context"
"encoding/base64"
"fmt"
"io"
Expand Down Expand Up @@ -222,6 +223,7 @@ func TestConnector_ExecuteSpeechRecognition(t *testing.T) {
func testTask(c *qt.C, p taskParams) {
bc := base.Connector{Logger: zap.NewNop()}
connector := Init(bc)
ctx := context.Background()

c.Run("nok - HTTP client error - "+p.task, func(c *qt.C) {
c.Parallel()
Expand All @@ -238,7 +240,7 @@ func testTask(c *qt.C, p taskParams) {
c.Assert(err, qt.IsNil)
pbIn.Fields["model"] = structpb.NewStringValue(model)

_, err = exec.Execution.Execute([]*structpb.Struct{pbIn})
_, err = exec.Execution.Execute(ctx, []*structpb.Struct{pbIn})
c.Check(err, qt.IsNotNil)
c.Check(err, qt.ErrorMatches, ".*no such host")
c.Check(errmsg.Message(err), qt.Matches, "Failed to call .*check that the connector configuration is correct.")
Expand Down Expand Up @@ -326,7 +328,7 @@ func testTask(c *qt.C, p taskParams) {
c.Assert(err, qt.IsNil)
pbIn.Fields["model"] = structpb.NewStringValue(model)

got, err := exec.Execution.Execute([]*structpb.Struct{pbIn})
got, err := exec.Execution.Execute(ctx, []*structpb.Struct{pbIn})
if tc.wantErr != "" {
c.Check(err, qt.IsNotNil)
c.Check(errmsg.Message(err), qt.Equals, tc.wantErr)
Expand Down
3 changes: 2 additions & 1 deletion pkg/connector/huggingface/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package huggingface

import (
"context"
_ "embed"
"encoding/base64"
"encoding/json"
Expand Down Expand Up @@ -101,7 +102,7 @@ func wrapSliceInStruct(data []byte, key string) (*structpb.Struct, error) {
}, nil
}

func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
client := newClient(e.Connection, e.GetLogger())
outputs := []*structpb.Struct{}

Expand Down
4 changes: 2 additions & 2 deletions pkg/connector/instill/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func getRequestMetadata(vars map[string]any) metadata.MD {
)
}

func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(ctx context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
var err error

if len(inputs) <= 0 || inputs[0] == nil {
Expand All @@ -116,7 +116,7 @@ func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, erro
}

modelNameSplits := strings.Split(inputs[0].GetFields()["model_name"].GetStringValue(), "/")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()

ctx = metadata.NewOutgoingContext(ctx, getRequestMetadata(e.SystemVariables))
Expand Down
3 changes: 2 additions & 1 deletion pkg/connector/numbers/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package numbers

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -191,7 +192,7 @@ func (e *execution) registerAsset(data []byte, reg Register) (string, error) {
}
}

func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
func (e *execution) Execute(_ context.Context, inputs []*structpb.Struct) ([]*structpb.Struct, error) {

var outputs []*structpb.Struct

Expand Down
Loading

0 comments on commit 80c4780

Please sign in to comment.