Skip to content

Commit

Permalink
feat: handle stop run and skip polling during polling step (#3552)
Browse files Browse the repository at this point in the history
  • Loading branch information
schoren committed Jan 22, 2024
1 parent 217261a commit 2d6b79f
Show file tree
Hide file tree
Showing 15 changed files with 734 additions and 646 deletions.
1,018 changes: 514 additions & 504 deletions agent/proto/orchestrator.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions agent/proto/orchestrator.proto
Expand Up @@ -53,6 +53,7 @@ message Empty {}
message StopRequest{
string testID = 1;
int32 runID = 2;
// type is forwarded by the agent's stop worker as the cancellation cause of
// the running process, which lets the stopped worker distinguish between
// kinds of stop requests (e.g. stopping the run vs. skipping trace
// collection). NOTE(review): the set of accepted values is defined by the
// server's executor package - confirm there.
string type = 3;
}

// ConnectRequest is the initial request sent by the agent to the orchestrator
Expand Down
2 changes: 1 addition & 1 deletion agent/proto/orchestrator_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions agent/runner/session.go
Expand Up @@ -111,24 +111,25 @@ func newControlPlaneClient(ctx context.Context, config config.Config, traceCache
return nil, err
}

cancelFuncs := workers.NewCancelFuncMap()
processStopper := workers.NewProcessStopper()

stopWorker := workers.NewStopperWorker(
workers.WithStopperObserver(observer),
workers.WithStopperCancelFuncList(cancelFuncs),
workers.WithStopperCancelFuncList(processStopper.CancelMap()),
)

triggerWorker := workers.NewTriggerWorker(
controlPlaneClient,
workers.WithTraceCache(traceCache),
workers.WithTriggerObserver(observer),
workers.WithTriggerCancelFuncList(cancelFuncs),
workers.WithTriggerStoppableProcessRunner(processStopper.RunStoppableProcess),
)

pollingWorker := workers.NewPollerWorker(
controlPlaneClient,
workers.WithInMemoryDatastore(poller.NewInMemoryDatastore(traceCache)),
workers.WithObserver(observer),
workers.WithPollerStoppableProcessRunner(processStopper.RunStoppableProcess),
)

dataStoreTestConnectionWorker := workers.NewTestConnectionWorker(controlPlaneClient, observer)
Expand Down
46 changes: 0 additions & 46 deletions agent/workers/cancel_func_map.go

This file was deleted.

31 changes: 24 additions & 7 deletions agent/workers/poller.go
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/kubeshop/tracetest/agent/event"
"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/datastore"
"github.com/kubeshop/tracetest/server/executor"
"github.com/kubeshop/tracetest/server/tracedb"
"github.com/kubeshop/tracetest/server/tracedb/connection"
"github.com/kubeshop/tracetest/server/traces"
Expand All @@ -21,12 +22,13 @@ import (
)

type PollerWorker struct {
client *client.Client
tracer trace.Tracer
sentSpanIDs *gocache.Cache[string, bool]
inmemoryDatastore tracedb.TraceDB
logger *zap.Logger
observer event.Observer
client *client.Client
tracer trace.Tracer
sentSpanIDs *gocache.Cache[string, bool]
inmemoryDatastore tracedb.TraceDB
logger *zap.Logger
observer event.Observer
stoppableProcessRunner StoppableProcessRunner
}

type PollerOption func(*PollerWorker)
Expand All @@ -43,6 +45,12 @@ func WithObserver(observer event.Observer) PollerOption {
}
}

// WithPollerStoppableProcessRunner returns a PollerOption that installs the
// given StoppableProcessRunner on the PollerWorker. The worker invokes it from
// Poll to execute the polling step in a way that can be stopped.
func WithPollerStoppableProcessRunner(stoppableProcessRunner StoppableProcessRunner) PollerOption {
	return func(worker *PollerWorker) {
		worker.stoppableProcessRunner = stoppableProcessRunner
	}
}

func NewPollerWorker(client *client.Client, opts ...PollerOption) *PollerWorker {
// TODO: use a real tracer
tracer := trace.NewNoopTracerProvider().Tracer("noop")
Expand Down Expand Up @@ -70,7 +78,16 @@ func (w *PollerWorker) Poll(ctx context.Context, request *proto.PollingRequest)
w.logger.Debug("Received polling request", zap.Any("request", request))
w.observer.StartTracePoll(request)

err := w.poll(ctx, request)
var err error
w.stoppableProcessRunner(ctx, request.TestID, request.RunID, func(ctx context.Context) {
err = w.poll(ctx, request)
}, func(cause string) {
err = executor.ErrUserCancelled
if cause == string(executor.UserRequestTypeSkipTraceCollection) {
err = executor.ErrSkipTraceCollection
}
})

if err != nil {
w.logger.Error("Error polling", zap.Error(err))
errorResponse := &proto.PollingResponse{
Expand Down
12 changes: 7 additions & 5 deletions agent/workers/poller_test.go
Expand Up @@ -26,7 +26,7 @@ func TestPollerWorker(t *testing.T) {
client, err := client.Connect(ctx, controlPlane.Addr())
require.NoError(t, err)

pollerWorker := workers.NewPollerWorker(client)
pollerWorker := workers.NewPollerWorker(client, workers.WithPollerStoppableProcessRunner(workers.NewProcessStopper().RunStoppableProcess))

client.OnPollingRequest(func(ctx context.Context, pr *proto.PollingRequest) error {
return pollerWorker.Poll(ctx, pr)
Expand Down Expand Up @@ -125,7 +125,9 @@ func TestPollerWorkerWithInmemoryDatastore(t *testing.T) {

pollerWorker := workers.NewPollerWorker(client, workers.WithInMemoryDatastore(
poller.NewInMemoryDatastore(cache),
))
),
workers.WithPollerStoppableProcessRunner(workers.NewProcessStopper().RunStoppableProcess),
)

client.OnPollingRequest(func(ctx context.Context, pr *proto.PollingRequest) error {
return pollerWorker.Poll(ctx, pr)
Expand Down Expand Up @@ -182,7 +184,7 @@ func TestPollerWithInvalidDataStore(t *testing.T) {
client, err := client.Connect(ctx, controlPlane.Addr())
require.NoError(t, err)

pollerWorker := workers.NewPollerWorker(client)
pollerWorker := workers.NewPollerWorker(client, workers.WithPollerStoppableProcessRunner(workers.NewProcessStopper().RunStoppableProcess))

client.OnPollingRequest(func(ctx context.Context, pr *proto.PollingRequest) error {
return pollerWorker.Poll(ctx, pr)
Expand All @@ -200,7 +202,7 @@ func TestPollerWithInvalidDataStore(t *testing.T) {
Tempo: &proto.TempoConfig{
Type: "http",
Http: &proto.HttpClientSettings{
Url: "http://localhost:16686", // invalid jaeger port, this should cause an error
Url: "http://localhost:12312", // invalid tempo port, this should cause an error
},
},
},
Expand All @@ -212,6 +214,6 @@ func TestPollerWithInvalidDataStore(t *testing.T) {

pollingResponse := controlPlane.GetLastPollingResponse()
require.NotNil(t, pollingResponse, "agent did not send polling response back to server")
assert.NotNil(t, pollingResponse.Error)
require.NotNil(t, pollingResponse.Error)
assert.Contains(t, pollingResponse.Error.Message, "connection refused")
}
92 changes: 92 additions & 0 deletions agent/workers/process_stopper.go
@@ -0,0 +1,92 @@
package workers

import (
	"context"
	"strconv"
	"sync"

	gocache "github.com/Code-Hex/go-generics-cache"
)

// StoppableProcessRunner is the signature of a function that executes worker
// on behalf of the run identified by testID and runID, handing it a context
// derived from parentCtx, and that calls stopedCallback with a textual cause
// when the process is stopped before the worker completes.
// processStopper.RunStoppableProcess is the implementation provided by this
// file.
type StoppableProcessRunner func(parentCtx context.Context, testID string, runID int32, worker func(context.Context), stopedCallback func(cause string))

// NewProcessStopper builds a processStopper backed by a freshly created
// cancel-function map.
func NewProcessStopper() processStopper {
	var stopper processStopper
	stopper.cancelMap = newCancelCauseFuncMap()
	return stopper
}

// processStopper runs processes that can be cancelled from the outside. It
// keeps the cancel function of every process it is currently running in
// cancelMap.
type processStopper struct {
// cancelMap stores one context.CancelCauseFunc per running process, indexed
// by key(testID, runID). It is a pointer, so copies of processStopper share
// the same map.
cancelMap *cancelCauseFuncMap
}

// CancelMap returns the map holding the cancel functions of the processes
// started through RunStoppableProcess. The returned pointer is the stopper's
// own map, not a copy.
func (p processStopper) CancelMap() *cancelCauseFuncMap {
return p.cancelMap
}

// RunStoppableProcess runs worker in its own goroutine, passing it a
// cancellable child of parentCtx, and blocks until either the worker returns
// or that child context is done.
//
// For the duration of the call the child context's cancel function is stored
// in the cancel map under key(testID, runID), so that a holder of the map
// (see CancelMap) can stop this particular run and attach a cause to it.
//
// If the child context is done before the worker returns, stopedCallback is
// called with the message of the cancellation cause reported by
// context.Cause, falling back to "cancelled" when no cause is reported. In
// that case the worker is not waited for: it keeps running in the background
// until it returns on its own, which it is expected to do once it observes the
// cancelled context.
func (p processStopper) RunStoppableProcess(parentCtx context.Context, testID string, runID int32, worker func(context.Context), stopedCallback func(cause string)) {
	// done is closed, never sent on, by the worker goroutine. A send on an
	// unbuffered channel would block forever (leaking the goroutine) whenever
	// the process is stopped first, because nobody receives after this
	// function has returned.
	done := make(chan struct{})

	// create a subcontext for the worker so when canceled it doesn't affect the parent context
	subcontext, cancelSubctx := context.WithCancelCause(parentCtx)
	defer cancelSubctx(nil)

	cacheKey := key(testID, runID)
	p.cancelMap.Set(cacheKey, cancelSubctx)
	defer p.cancelMap.Del(cacheKey)

	go func() {
		defer close(done)
		worker(subcontext)
	}()

	select {
	case <-done:
		// worker finished before any stop was requested
	case <-subcontext.Done():
		cause := "cancelled"
		if err := context.Cause(subcontext); err != nil {
			cause = err.Error()
		}
		stopedCallback(cause)
	}
}

// key builds the cancel-map key that identifies a single run of a test.
//
// The run ID is rendered in decimal. A plain string(runID) conversion would
// instead produce the UTF-8 encoding of the code point runID, and every value
// that is not a valid code point (negative numbers, surrogates, ...) would
// collapse to "\uFFFD", making distinct runs share one key. The ":" separator
// keeps pairs such as ("a", 12) and ("a1", 2) apart.
func key(testID string, runID int32) string {
	return testID + ":" + strconv.FormatInt(int64(runID), 10)
}

// cancelCauseFuncMap stores context.CancelCauseFunc values indexed by a string
// key. Every accessor method takes mutex before touching internalMap.
type cancelCauseFuncMap struct {
// mutex serializes Get, Set and Del.
mutex sync.Mutex
// internalMap is the underlying key -> cancel function store.
internalMap *gocache.Cache[string, context.CancelCauseFunc]
}

// Get looks up the cancel function stored under key. The boolean result
// reports whether an entry was found.
func (c *cancelCauseFuncMap) Get(key string) (context.CancelCauseFunc, bool) {
c.mutex.Lock()
defer c.mutex.Unlock()

return c.internalMap.Get(key)
}

// Set stores cancelFn under key.
func (c *cancelCauseFuncMap) Set(key string, cancelFn context.CancelCauseFunc) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.internalMap.Set(key, cancelFn)
}

// Del deletes the entry stored under key from the map.
func (c *cancelCauseFuncMap) Del(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.internalMap.Delete(key)
}

// newCancelCauseFuncMap allocates a cancelCauseFuncMap whose internal store is
// a newly created gocache cache; the mutex is left at its usable zero value.
func newCancelCauseFuncMap() *cancelCauseFuncMap {
	m := new(cancelCauseFuncMap)
	m.internalMap = gocache.New[string, context.CancelCauseFunc]()
	return m
}
7 changes: 4 additions & 3 deletions agent/workers/stopper.go
Expand Up @@ -2,6 +2,7 @@ package workers

import (
"context"
"errors"
"fmt"

"github.com/kubeshop/tracetest/agent/event"
Expand All @@ -12,12 +13,12 @@ import (
type StopperWorker struct {
logger *zap.Logger
observer event.Observer
cancelContexts *cancelFuncMap
cancelContexts *cancelCauseFuncMap
}

type StopperOption func(*StopperWorker)

func WithStopperCancelFuncList(cancelContexts *cancelFuncMap) StopperOption {
func WithStopperCancelFuncList(cancelContexts *cancelCauseFuncMap) StopperOption {
return func(tw *StopperWorker) {
tw.cancelContexts = cancelContexts
}
Expand Down Expand Up @@ -59,7 +60,7 @@ func (w *StopperWorker) Stop(ctx context.Context, stopRequest *proto.StopRequest
return err
}

cancelFn()
cancelFn(errors.New(stopRequest.Type))

w.observer.EndStopRequest(stopRequest, nil)

Expand Down
40 changes: 11 additions & 29 deletions agent/workers/trigger.go
Expand Up @@ -18,19 +18,19 @@ import (
)

type TriggerWorker struct {
logger *zap.Logger
client *client.Client
registry *agentTrigger.Registry
traceCache collector.TraceCache
observer event.Observer
cancelMap *cancelFuncMap
logger *zap.Logger
client *client.Client
registry *agentTrigger.Registry
traceCache collector.TraceCache
observer event.Observer
stoppableProcessRunner StoppableProcessRunner
}

type TriggerOption func(*TriggerWorker)

func WithTriggerCancelFuncList(cancelContexts *cancelFuncMap) TriggerOption {
func WithTriggerStoppableProcessRunner(stoppableProcessRunner StoppableProcessRunner) TriggerOption {
return func(tw *TriggerWorker) {
tw.cancelMap = cancelContexts
tw.stoppableProcessRunner = stoppableProcessRunner
}
}

Expand Down Expand Up @@ -78,30 +78,12 @@ func (w *TriggerWorker) Trigger(ctx context.Context, triggerRequest *proto.Trigg
w.logger.Debug("Trigger request received", zap.Any("triggerRequest", triggerRequest))
w.observer.StartTriggerExecution(triggerRequest)

done := make(chan bool)

subcontext, cancelSubctx := context.WithCancel(ctx)
defer cancelSubctx()

cacheKey := key(triggerRequest.TestID, triggerRequest.RunID)
w.cancelMap.Set(cacheKey, cancelSubctx)
defer w.cancelMap.Del(cacheKey)

var err error
go func() {
w.stoppableProcessRunner(ctx, triggerRequest.TestID, triggerRequest.RunID, func(subcontext context.Context) {
err = w.trigger(subcontext, triggerRequest)
done <- true
}()

select {
case <-done:
// trigger finished successfully
break
case <-subcontext.Done():
// The context was cancelled.
}, func(_ string) {
err = executor.ErrUserCancelled
break
}
})

if err != nil {
w.logger.Error("Trigger error", zap.Error(err))
Expand Down

0 comments on commit 2d6b79f

Please sign in to comment.