From 72cd99955d85202b0a1a15687242176c650e4d7c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 19 May 2023 09:17:01 +0200 Subject: [PATCH] chore: add pre callbacks (#68) --- examples/pipeline/callbacks/main.go | 18 ++++---- pipeline/pipeline.go | 67 ++++++++++++++++++++++------- 2 files changed, 59 insertions(+), 26 deletions(-) diff --git a/examples/pipeline/callbacks/main.go b/examples/pipeline/callbacks/main.go index 728848fe..2713cfbc 100644 --- a/examples/pipeline/callbacks/main.go +++ b/examples/pipeline/callbacks/main.go @@ -31,26 +31,24 @@ func main() { }, ) - cbTranslate := pipeline.PipelineCallback(func(output types.M) (types.M, error) { - iterator++ - return output, nil + translatePreCallback := pipeline.PipelineCallback(func(input types.M) (types.M, error) { + input["language"] = languages[iterator] + input["sentence"] = sentence + return input, nil }) - cbExpand := pipeline.PipelineCallback(func(output types.M) (types.M, error) { - + expandPostCallback := pipeline.PipelineCallback(func(output types.M) (types.M, error) { + iterator++ if iterator >= len(languages) { pipeline.SetNextTubeExit(output) } else { pipeline.SetNextTube(output, 0) - output["language"] = languages[iterator] - output["sentence"] = sentence } - return output, nil }) - pipeLine := pipeline.New(translate, expand).WithCallbacks(cbTranslate, cbExpand) + pipeLine := pipeline.New(translate, expand).WithPreCallbacks(translatePreCallback, nil).WithPostCallbacks(nil, expandPostCallback) - pipeLine.Run(context.Background(), types.M{"sentence": sentence, "language": languages[iterator]}) + pipeLine.Run(context.Background(), nil) } diff --git a/pipeline/pipeline.go b/pipeline/pipeline.go index d7e6232f..f59398c2 100644 --- a/pipeline/pipeline.go +++ b/pipeline/pipeline.go @@ -31,35 +31,64 @@ type Pipe interface { Run(ctx context.Context, input types.M) (types.M, error) } -type PipelineCallback func(input types.M) (types.M, error) +type PipelineCallback func(values types.M) (types.M, error) type pipeline struct { - pipes []Pipe - callbacks []PipelineCallback + pipes map[int]Pipe + preCallbacks map[int]PipelineCallback + postCallbacks map[int]PipelineCallback } func New(pipes ...Pipe) *pipeline { + + pipesMap := make(map[int]Pipe) + for i, pipe := range pipes { + pipesMap[i] = pipe + } + return &pipeline{ - pipes: pipes, + pipes: pipesMap, } } -func (p *pipeline) WithCallbacks(callbacks ...PipelineCallback) pipeline { - p.callbacks = callbacks - return *p +func (p *pipeline) WithPreCallbacks(callbacks ...PipelineCallback) *pipeline { + + p.preCallbacks = make(map[int]PipelineCallback) + for i, callback := range callbacks { + p.preCallbacks[i] = callback + } + + return p +} + +func (p *pipeline) WithPostCallbacks(callbacks ...PipelineCallback) *pipeline { + + p.postCallbacks = make(map[int]PipelineCallback) + for i, callback := range callbacks { + p.postCallbacks[i] = callback + } + + return p } // Run chains the steps of the pipeline and returns the output of the last step. func (p pipeline) Run(ctx context.Context, input types.M) (types.M, error) { var err error - var output types.M - currentTube := -1 + currentTube := 0 + + if input == nil { + input = types.M{} + } + + output := input for { - if currentTube == -1 { - currentTube = 0 - output = input + if p.thereIsAValidPreCallbackForTube(currentTube) { + output, err = p.preCallbacks[currentTube](output) + if err != nil { + return nil, err + } } output, err = p.pipes[currentTube].Run(ctx, output) @@ -67,8 +96,8 @@ func (p pipeline) Run(ctx context.Context, input types.M) (types.M, error) { return nil, err } - if p.thereIsAValidCallbackForTube(currentTube) { - output, err = p.callbacks[currentTube](output) + if p.thereIsAValidPostCallbackForTube(currentTube) { + output, err = p.postCallbacks[currentTube](output) if err != nil { return nil, err } @@ -105,8 +134,14 @@ func SetNextTubeExit(output types.M) types.M { return output } -func (p *pipeline) thereIsAValidCallbackForTube(currentTube int) bool { - return len(p.callbacks) == len(p.pipes) && p.callbacks[currentTube] != nil +func (p *pipeline) thereIsAValidPreCallbackForTube(currentTube int) bool { + cb, ok := p.preCallbacks[currentTube] + return cb != nil && ok +} + +func (p *pipeline) thereIsAValidPostCallbackForTube(currentTube int) bool { + cb, ok := p.postCallbacks[currentTube] + return cb != nil && ok } func (p *pipeline) getNextTube(output types.M) *int {