Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: handle stop run and skip polling during polling step #3552

Merged
merged 10 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1,018 changes: 514 additions & 504 deletions agent/proto/orchestrator.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions agent/proto/orchestrator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ message Empty {}
message StopRequest{
string testID = 1;
int32 runID = 2;
string type = 3;
}

// ConnectRequest is the initial request sent by the agent to the orchestrator
Expand Down
2 changes: 1 addition & 1 deletion agent/proto/orchestrator_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions agent/runner/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,25 @@ func newControlPlaneClient(ctx context.Context, config config.Config, traceCache
return nil, err
}

cancelFuncs := workers.NewCancelFuncMap()
processStopper := workers.NewProcessStopper()

stopWorker := workers.NewStopperWorker(
workers.WithStopperObserver(observer),
workers.WithStopperCancelFuncList(cancelFuncs),
workers.WithStopperCancelFuncList(processStopper.CancelMap()),
)

triggerWorker := workers.NewTriggerWorker(
controlPlaneClient,
workers.WithTraceCache(traceCache),
workers.WithTriggerObserver(observer),
workers.WithTriggerCancelFuncList(cancelFuncs),
workers.WithTriggerStoppableProcessRunner(processStopper.RunStoppableProcess),
)

pollingWorker := workers.NewPollerWorker(
controlPlaneClient,
workers.WithInMemoryDatastore(poller.NewInMemoryDatastore(traceCache)),
workers.WithObserver(observer),
workers.WithPollerStoppableProcessRunner(processStopper.RunStoppableProcess),
)

dataStoreTestConnectionWorker := workers.NewTestConnectionWorker(controlPlaneClient, observer)
Expand Down
46 changes: 0 additions & 46 deletions agent/workers/cancel_func_map.go

This file was deleted.

31 changes: 24 additions & 7 deletions agent/workers/poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/kubeshop/tracetest/agent/event"
"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/datastore"
"github.com/kubeshop/tracetest/server/executor"
"github.com/kubeshop/tracetest/server/tracedb"
"github.com/kubeshop/tracetest/server/tracedb/connection"
"github.com/kubeshop/tracetest/server/traces"
Expand All @@ -21,12 +22,13 @@ import (
)

type PollerWorker struct {
client *client.Client
tracer trace.Tracer
sentSpanIDs *gocache.Cache[string, bool]
inmemoryDatastore tracedb.TraceDB
logger *zap.Logger
observer event.Observer
client *client.Client
tracer trace.Tracer
sentSpanIDs *gocache.Cache[string, bool]
inmemoryDatastore tracedb.TraceDB
logger *zap.Logger
observer event.Observer
stoppableProcessRunner StoppableProcessRunner
}

type PollerOption func(*PollerWorker)
Expand All @@ -43,6 +45,12 @@ func WithObserver(observer event.Observer) PollerOption {
}
}

func WithPollerStoppableProcessRunner(stoppableProcessRunner StoppableProcessRunner) PollerOption {
return func(pw *PollerWorker) {
pw.stoppableProcessRunner = stoppableProcessRunner
}
}

func NewPollerWorker(client *client.Client, opts ...PollerOption) *PollerWorker {
// TODO: use a real tracer
tracer := trace.NewNoopTracerProvider().Tracer("noop")
Expand Down Expand Up @@ -70,7 +78,16 @@ func (w *PollerWorker) Poll(ctx context.Context, request *proto.PollingRequest)
w.logger.Debug("Received polling request", zap.Any("request", request))
w.observer.StartTracePoll(request)

err := w.poll(ctx, request)
var err error
w.stoppableProcessRunner(ctx, request.TestID, request.RunID, func(ctx context.Context) {
err = w.poll(ctx, request)
}, func(cause string) {
err = executor.ErrUserCancelled
if cause == string(executor.UserRequestTypeSkipTraceCollection) {
err = executor.ErrSkipTraceCollection
}
})

if err != nil {
w.logger.Error("Error polling", zap.Error(err))
errorResponse := &proto.PollingResponse{
Expand Down
12 changes: 7 additions & 5 deletions agent/workers/poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestPollerWorker(t *testing.T) {
client, err := client.Connect(ctx, controlPlane.Addr())
require.NoError(t, err)

pollerWorker := workers.NewPollerWorker(client)
pollerWorker := workers.NewPollerWorker(client, workers.WithPollerStoppableProcessRunner(workers.NewProcessStopper().RunStoppableProcess))

client.OnPollingRequest(func(ctx context.Context, pr *proto.PollingRequest) error {
return pollerWorker.Poll(ctx, pr)
Expand Down Expand Up @@ -125,7 +125,9 @@ func TestPollerWorkerWithInmemoryDatastore(t *testing.T) {

pollerWorker := workers.NewPollerWorker(client, workers.WithInMemoryDatastore(
poller.NewInMemoryDatastore(cache),
))
),
workers.WithPollerStoppableProcessRunner(workers.NewProcessStopper().RunStoppableProcess),
)

client.OnPollingRequest(func(ctx context.Context, pr *proto.PollingRequest) error {
return pollerWorker.Poll(ctx, pr)
Expand Down Expand Up @@ -182,7 +184,7 @@ func TestPollerWithInvalidDataStore(t *testing.T) {
client, err := client.Connect(ctx, controlPlane.Addr())
require.NoError(t, err)

pollerWorker := workers.NewPollerWorker(client)
pollerWorker := workers.NewPollerWorker(client, workers.WithPollerStoppableProcessRunner(workers.NewProcessStopper().RunStoppableProcess))

client.OnPollingRequest(func(ctx context.Context, pr *proto.PollingRequest) error {
return pollerWorker.Poll(ctx, pr)
Expand All @@ -200,7 +202,7 @@ func TestPollerWithInvalidDataStore(t *testing.T) {
Tempo: &proto.TempoConfig{
Type: "http",
Http: &proto.HttpClientSettings{
Url: "http://localhost:16686", // invalid jaeger port, this should cause an error
Url: "http://localhost:12312", // invalid tempo port, this should cause an error
},
},
},
Expand All @@ -212,6 +214,6 @@ func TestPollerWithInvalidDataStore(t *testing.T) {

pollingResponse := controlPlane.GetLastPollingResponse()
require.NotNil(t, pollingResponse, "agent did not send polling response back to server")
assert.NotNil(t, pollingResponse.Error)
require.NotNil(t, pollingResponse.Error)
assert.Contains(t, pollingResponse.Error.Message, "connection refused")
}
92 changes: 92 additions & 0 deletions agent/workers/process_stopper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package workers

import (
"context"
"sync"

gocache "github.com/Code-Hex/go-generics-cache"
)

type StoppableProcessRunner func(parentCtx context.Context, testID string, runID int32, worker func(context.Context), stopedCallback func(cause string))

func NewProcessStopper() processStopper {
return processStopper{
cancelMap: newCancelCauseFuncMap(),
}
}

type processStopper struct {
cancelMap *cancelCauseFuncMap
}

func (p processStopper) CancelMap() *cancelCauseFuncMap {
return p.cancelMap
}

func (p processStopper) RunStoppableProcess(parentCtx context.Context, testID string, runID int32, worker func(context.Context), stopedCallback func(cause string)) {
done := make(chan bool)

// create a subcontext for the worker so when canceled it doesn't affect the parent context
subcontext, cancelSubctx := context.WithCancelCause(parentCtx)
defer cancelSubctx(nil)

cacheKey := key(testID, runID)
p.cancelMap.Set(cacheKey, cancelSubctx)
defer p.cancelMap.Del(cacheKey)

go func() {
worker(subcontext)
done <- true
}()

select {
case <-done:
// trigger finished successfully
break
case <-subcontext.Done():
cause := "cancelled"
if err := context.Cause(subcontext); err != nil {
cause = err.Error()
}
stopedCallback(cause)
break
}
}

func key(testID string, runID int32) string {
return testID + string(runID)
}

type cancelCauseFuncMap struct {
mutex sync.Mutex
internalMap *gocache.Cache[string, context.CancelCauseFunc]
}

// Get implements TraceCache.
func (c *cancelCauseFuncMap) Get(key string) (context.CancelCauseFunc, bool) {
c.mutex.Lock()
defer c.mutex.Unlock()

return c.internalMap.Get(key)
}

// Append implements TraceCache.
func (c *cancelCauseFuncMap) Set(key string, cancelFn context.CancelCauseFunc) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.internalMap.Set(key, cancelFn)
}

func (c *cancelCauseFuncMap) Del(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.internalMap.Delete(key)
}

func newCancelCauseFuncMap() *cancelCauseFuncMap {
return &cancelCauseFuncMap{
internalMap: gocache.New[string, context.CancelCauseFunc](),
}
}
7 changes: 4 additions & 3 deletions agent/workers/stopper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package workers

import (
"context"
"errors"
"fmt"

"github.com/kubeshop/tracetest/agent/event"
Expand All @@ -12,12 +13,12 @@ import (
type StopperWorker struct {
logger *zap.Logger
observer event.Observer
cancelContexts *cancelFuncMap
cancelContexts *cancelCauseFuncMap
}

type StopperOption func(*StopperWorker)

func WithStopperCancelFuncList(cancelContexts *cancelFuncMap) StopperOption {
func WithStopperCancelFuncList(cancelContexts *cancelCauseFuncMap) StopperOption {
return func(tw *StopperWorker) {
tw.cancelContexts = cancelContexts
}
Expand Down Expand Up @@ -59,7 +60,7 @@ func (w *StopperWorker) Stop(ctx context.Context, stopRequest *proto.StopRequest
return err
}

cancelFn()
cancelFn(errors.New(stopRequest.Type))

w.observer.EndStopRequest(stopRequest, nil)

Expand Down
40 changes: 11 additions & 29 deletions agent/workers/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ import (
)

type TriggerWorker struct {
logger *zap.Logger
client *client.Client
registry *agentTrigger.Registry
traceCache collector.TraceCache
observer event.Observer
cancelMap *cancelFuncMap
logger *zap.Logger
client *client.Client
registry *agentTrigger.Registry
traceCache collector.TraceCache
observer event.Observer
stoppableProcessRunner StoppableProcessRunner
}

type TriggerOption func(*TriggerWorker)

func WithTriggerCancelFuncList(cancelContexts *cancelFuncMap) TriggerOption {
func WithTriggerStoppableProcessRunner(stoppableProcessRunner StoppableProcessRunner) TriggerOption {
return func(tw *TriggerWorker) {
tw.cancelMap = cancelContexts
tw.stoppableProcessRunner = stoppableProcessRunner
}
}

Expand Down Expand Up @@ -78,30 +78,12 @@ func (w *TriggerWorker) Trigger(ctx context.Context, triggerRequest *proto.Trigg
w.logger.Debug("Trigger request received", zap.Any("triggerRequest", triggerRequest))
w.observer.StartTriggerExecution(triggerRequest)

done := make(chan bool)

subcontext, cancelSubctx := context.WithCancel(ctx)
defer cancelSubctx()

cacheKey := key(triggerRequest.TestID, triggerRequest.RunID)
w.cancelMap.Set(cacheKey, cancelSubctx)
defer w.cancelMap.Del(cacheKey)

var err error
go func() {
w.stoppableProcessRunner(ctx, triggerRequest.TestID, triggerRequest.RunID, func(subcontext context.Context) {
err = w.trigger(subcontext, triggerRequest)
done <- true
}()

select {
case <-done:
// trigger finished successfully
break
case <-subcontext.Done():
// The context was cancelled.
}, func(_ string) {
err = executor.ErrUserCancelled
break
}
})

if err != nil {
w.logger.Error("Trigger error", zap.Error(err))
Expand Down