Skip to content

Commit

Permalink
Merge pull request #437 from nadinelyab/nadin/stream-interceptor
Browse files Browse the repository at this point in the history
Adds a stream interceptor to keep state between the send and receive calls
  • Loading branch information
bojand committed May 5, 2024
2 parents f778079 + 149d520 commit 6bc85c8
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 22 deletions.
10 changes: 10 additions & 0 deletions runner/data.go
Expand Up @@ -44,6 +44,16 @@ type StreamMessageProviderFunc func(*CallData) (*dynamic.Message, error)
// Clients can return ErrEndStream to end the call early
type StreamRecvMsgInterceptFunc func(*dynamic.Message, error) error

// StreamInterceptorProviderFunc is an interface for a function invoked to generate a stream interceptor
type StreamInterceptorProviderFunc func(*CallData) StreamInterceptor

// StreamInterceptor is an interface for sending and receiving stream messages.
// The interceptor can keep shared state for the send and receive calls.
type StreamInterceptor interface {
Recv(*dynamic.Message, error) error
Send(*CallData) (*dynamic.Message, error)
}

type dataProvider struct {
binary bool
data []byte
Expand Down
22 changes: 16 additions & 6 deletions runner/options.go
Expand Up @@ -129,12 +129,13 @@ type RunConfig struct {
disableTemplateData bool

// misc
name string
cpus int
tags []byte
skipFirst int
countErrors bool
recvMsgFunc StreamRecvMsgInterceptFunc
name string
cpus int
tags []byte
skipFirst int
countErrors bool
recvMsgFunc StreamRecvMsgInterceptFunc
streamInterceptorProviderFunc StreamInterceptorProviderFunc
}

// Option controls some aspect of run
Expand Down Expand Up @@ -1034,6 +1035,15 @@ func WithStreamRecvMsgIntercept(fn StreamRecvMsgInterceptFunc) Option {
}
}

// WithStreamInterceptor specifies the stream interceptor provider function
func WithStreamInterceptorProviderFunc(interceptor StreamInterceptorProviderFunc) Option {
return func(o *RunConfig) error {
o.streamInterceptorProviderFunc = interceptor

return nil
}
}

// WithDataProvider provides custom data provider
//
// WithDataProvider(func(*CallData) ([]*dynamic.Message, error) {
Expand Down
23 changes: 12 additions & 11 deletions runner/requester.go
Expand Up @@ -389,17 +389,18 @@ func (b *Requester) runWorkers(wt load.WorkerTicker, p load.Pacer) error {
}

w := Worker{
ticks: ticks,
active: true,
stub: b.stubs[n],
mtd: b.mtd,
config: b.config,
stopCh: make(chan bool),
workerID: wID,
dataProvider: b.dataProvider,
metadataProvider: b.metadataProvider,
streamRecv: b.config.recvMsgFunc,
msgProvider: b.config.dataStreamFunc,
ticks: ticks,
active: true,
stub: b.stubs[n],
mtd: b.mtd,
config: b.config,
stopCh: make(chan bool),
workerID: wID,
dataProvider: b.dataProvider,
metadataProvider: b.metadataProvider,
streamRecv: b.config.recvMsgFunc,
msgProvider: b.config.dataStreamFunc,
streamInterceptorProviderFunc: b.config.streamInterceptorProviderFunc,
}

wc++ // increment worker id
Expand Down
45 changes: 40 additions & 5 deletions runner/worker.go
Expand Up @@ -40,7 +40,8 @@ type Worker struct {
metadataProvider MetadataProviderFunc
msgProvider StreamMessageProviderFunc

streamRecv StreamRecvMsgInterceptFunc
streamRecv StreamRecvMsgInterceptFunc
streamInterceptorProviderFunc StreamInterceptorProviderFunc
}

func (w *Worker) runWorker() error {
Expand Down Expand Up @@ -83,6 +84,13 @@ func (w *Worker) makeRequest(tv TickValue) error {

ctd := newCallData(w.mtd, w.workerID, reqNum, !w.config.disableTemplateFuncs, !w.config.disableTemplateData, w.config.funcs)

var streamInterceptor StreamInterceptor
if w.mtd.IsClientStreaming() || w.mtd.IsServerStreaming() {
if w.streamInterceptorProviderFunc != nil {
streamInterceptor = w.streamInterceptorProviderFunc(ctd)
}
}

reqMD, err := w.metadataProvider(ctd)
if err != nil {
return err
Expand Down Expand Up @@ -115,6 +123,8 @@ func (w *Worker) makeRequest(tv TickValue) error {
var msgProvider StreamMessageProviderFunc
if w.msgProvider != nil {
msgProvider = w.msgProvider
} else if streamInterceptor != nil {
msgProvider = streamInterceptor.Send
} else if w.mtd.IsClientStreaming() {
if w.config.streamDynamicMessages {
mp, err := newDynamicMessageProvider(w.mtd, w.config.data, w.config.streamCallCount, !w.config.disableTemplateFuncs, !w.config.disableTemplateData)
Expand Down Expand Up @@ -155,11 +165,11 @@ func (w *Worker) makeRequest(tv TickValue) error {

// RPC errors are handled via stats handler
if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() {
_ = w.makeBidiRequest(&ctx, ctd, msgProvider)
_ = w.makeBidiRequest(&ctx, ctd, msgProvider, streamInterceptor)
} else if w.mtd.IsClientStreaming() {
_ = w.makeClientStreamingRequest(&ctx, ctd, msgProvider)
} else if w.mtd.IsServerStreaming() {
_ = w.makeServerStreamingRequest(&ctx, inputs[0])
_ = w.makeServerStreamingRequest(&ctx, inputs[0], streamInterceptor)
} else {
_ = w.makeUnaryRequest(&ctx, reqMD, inputs[0])
}
Expand Down Expand Up @@ -314,7 +324,7 @@ func (w *Worker) makeClientStreamingRequest(ctx *context.Context,
return nil
}

func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic.Message) error {
func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic.Message, streamInterceptor StreamInterceptor) error {
var callOptions = []grpc.CallOption{}
if w.config.enableCompression {
callOptions = append(callOptions, grpc.UseCompressor(gzip.Name))
Expand Down Expand Up @@ -388,6 +398,18 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic
}
}

if streamInterceptor != nil {
if converted, ok := res.(*dynamic.Message); ok {
err = streamInterceptor.Recv(converted, err)
if errors.Is(err, ErrEndStream) && !interceptCanceled {
interceptCanceled = true
err = nil

callCancel()
}
}
}

if err != nil {
if err == io.EOF {
err = nil
Expand Down Expand Up @@ -415,7 +437,7 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic
}

func (w *Worker) makeBidiRequest(ctx *context.Context,
ctd *CallData, messageProvider StreamMessageProviderFunc) error {
ctd *CallData, messageProvider StreamMessageProviderFunc, streamInterceptor StreamInterceptor) error {

var callOptions = []grpc.CallOption{}

Expand Down Expand Up @@ -494,6 +516,19 @@ func (w *Worker) makeBidiRequest(ctx *context.Context,
}
}

if streamInterceptor != nil {
if converted, ok := res.(*dynamic.Message); ok {
iErr := streamInterceptor.Recv(converted, recvErr)
if errors.Is(iErr, ErrEndStream) && !interceptCanceled {
interceptCanceled = true
if len(cancel) == 0 {
cancel <- struct{}{}
}
recvErr = nil
}
}
}

if recvErr != nil {
close(recvDone)
break
Expand Down

0 comments on commit 6bc85c8

Please sign in to comment.