diff --git a/server/app/app.go b/server/app/app.go index bee6bb7e48..dc55bd9f13 100644 --- a/server/app/app.go +++ b/server/app/app.go @@ -11,6 +11,7 @@ import ( "github.com/gorilla/handlers" "github.com/gorilla/mux" + "github.com/jackc/pgx/v5/pgxpool" "github.com/kubeshop/tracetest/server/analytics" "github.com/kubeshop/tracetest/server/assertions/comparator" "github.com/kubeshop/tracetest/server/config" @@ -140,6 +141,16 @@ func (app *App) Start(opts ...appOption) error { fmt.Println("Starting") ctx := context.Background() + poolcfg, err := pgxpool.ParseConfig(app.cfg.PostgresConnString()) + if err != nil { + return err + } + + pool, err := pgxpool.NewWithConfig(context.Background(), poolcfg) + if err != nil { + return err + } + db, err := testdb.Connect(app.cfg.PostgresConnString()) if err != nil { return err @@ -212,6 +223,7 @@ func (app *App) Start(opts ...appOption) error { registerOtlpServer(app, tracesRepo, runRepo, eventEmitter, dataStoreRepo) testPipeline := buildTestPipeline( + pool, pollingProfileRepo, dataStoreRepo, linterRepo, diff --git a/server/app/test_pipeline.go b/server/app/test_pipeline.go index 14e1029a87..8e9eddb40a 100644 --- a/server/app/test_pipeline.go +++ b/server/app/test_pipeline.go @@ -1,6 +1,7 @@ package app import ( + "github.com/jackc/pgx/v5/pgxpool" "github.com/kubeshop/tracetest/server/datastore" "github.com/kubeshop/tracetest/server/executor" "github.com/kubeshop/tracetest/server/executor/pollingprofile" @@ -15,6 +16,7 @@ import ( ) func buildTestPipeline( + pool *pgxpool.Pool, ppRepo *pollingprofile.Repository, dsRepo *datastore.Repository, lintRepo *analyzer.Repository, @@ -86,15 +88,15 @@ func buildTestPipeline( WithTestGetter(testRepo). 
WithRunGetter(runRepo) + pgQueue := executor.NewPostgresQueueDriver(pool) + pipeline := executor.NewPipeline(queueBuilder, - executor.PipelineStep{Processor: runner, Driver: executor.NewInMemoryQueueDriver("runner")}, - executor.PipelineStep{Processor: tracePoller, Driver: executor.NewInMemoryQueueDriver("tracePoller")}, - executor.PipelineStep{Processor: linterRunner, Driver: executor.NewInMemoryQueueDriver("linterRunner")}, - executor.PipelineStep{Processor: assertionRunner, Driver: executor.NewInMemoryQueueDriver("assertionRunner")}, + executor.PipelineStep{Processor: runner, Driver: pgQueue.Channel("runner")}, + executor.PipelineStep{Processor: tracePoller, Driver: pgQueue.Channel("tracePoller")}, + executor.PipelineStep{Processor: linterRunner, Driver: pgQueue.Channel("linterRunner")}, + executor.PipelineStep{Processor: assertionRunner, Driver: pgQueue.Channel("assertionRunner")}, ) - pipeline.Start() - const assertionRunnerStepIndex = 3 return executor.NewTestPipeline( diff --git a/server/executor/queue.go b/server/executor/queue.go index d9a20dbdfb..c1a8498978 100644 --- a/server/executor/queue.go +++ b/server/executor/queue.go @@ -3,9 +3,9 @@ package executor import ( "context" "database/sql" + "encoding/json" "errors" "fmt" - "log" "strconv" "github.com/kubeshop/tracetest/server/datastore" @@ -18,6 +18,8 @@ import ( ) const ( + QueueWorkerCount = 5 + JobCountHeader string = "X-Tracetest-Job-Count" ) @@ -77,6 +79,46 @@ type Job struct { DataStore datastore.DataStore } +type jsonJob struct { + Headers *headers `json:"headers"` + TransactionID string `json:"transaction_id"` + TransactionRunID int `json:"transaction_run_id"` + TestID string `json:"test_id"` + RunID int `json:"run_id"` + PollingProfileID string `json:"polling_profile_id"` + DataStoreID string `json:"data_store_id"` +} + +func (job Job) MarshalJSON() ([]byte, error) { + return json.Marshal(jsonJob{ + Headers: job.Headers, + TransactionID: job.Transaction.ID.String(), + TransactionRunID: 
job.TransactionRun.ID, + TestID: job.Test.ID.String(), + RunID: job.Run.ID, + PollingProfileID: job.PollingProfile.ID.String(), + DataStoreID: job.DataStore.ID.String(), + }) +} + +func (job *Job) UnmarshalJSON(data []byte) error { + var jj jsonJob + err := json.Unmarshal(data, &jj) + if err != nil { + return err + } + + job.Headers = jj.Headers + job.Transaction.ID = id.ID(jj.TransactionID) + job.TransactionRun.ID = jj.TransactionRunID + job.Test.ID = id.ID(jj.TestID) + job.Run.ID = jj.RunID + job.PollingProfile.ID = id.ID(jj.PollingProfileID) + job.DataStore.ID = id.ID(jj.DataStoreID) + + return nil +} + func NewJob() Job { return Job{ Headers: &headers{}, @@ -252,7 +294,7 @@ func (q Queue) Enqueue(ctx context.Context, job Job) { Headers: job.Headers, Test: test.Test{ID: job.Test.ID}, - Run: job.Run, + Run: test.Run{ID: job.Run.ID}, Transaction: transaction.Transaction{ID: job.Transaction.ID}, TransactionRun: transaction.TransactionRun{ID: job.TransactionRun.ID}, @@ -277,8 +319,9 @@ func (q Queue) Listen(job Job) { } newJob.Test = q.resolveTest(ctx, job) // todo: revert when using actual queues - // newJob.Run = q.resolveTestRun(ctx, job) - newJob.Run = job.Run + newJob.Run = q.resolveTestRun(ctx, job) + // todo: change the otlp server to have its own table + // newJob.Run = job.Run newJob.Transaction = q.resolveTransaction(ctx, job) newJob.TransactionRun = q.resolveTransactionRun(ctx, job) @@ -427,50 +470,3 @@ func (q Queue) resolveDataStore(ctx context.Context, job Job) datastore.DataStor return ds } - -func NewInMemoryQueueDriver(name string) *InMemoryQueueDriver { - return &InMemoryQueueDriver{ - queue: make(chan Job), - exit: make(chan bool), - name: name, - } -} - -type InMemoryQueueDriver struct { - queue chan Job - exit chan bool - q *Queue - name string -} - -func (r *InMemoryQueueDriver) SetQueue(q *Queue) { - r.q = q -} - -func (r InMemoryQueueDriver) Enqueue(job Job) { - r.queue <- job -} - -const inMemoryQueueWorkerCount = 5 - -func (r 
InMemoryQueueDriver) Start() { - for i := 0; i < inMemoryQueueWorkerCount; i++ { - go func() { - log.Printf("[InMemoryQueueDriver - %s] start", r.name) - for { - select { - case <-r.exit: - log.Printf("[InMemoryQueueDriver - %s] exit", r.name) - return - case job := <-r.queue: - r.q.Listen(job) - } - } - }() - } -} - -func (r InMemoryQueueDriver) Stop() { - log.Printf("[InMemoryQueueDriver - %s] stopping", r.name) - r.exit <- true -} diff --git a/server/executor/queue_driver_in_memory.go b/server/executor/queue_driver_in_memory.go new file mode 100644 index 0000000000..d13895ca9b --- /dev/null +++ b/server/executor/queue_driver_in_memory.go @@ -0,0 +1,61 @@ +package executor + +import ( + "fmt" + "log" +) + +type loggerFn func(string, ...any) + +func newLoggerFn(name string) loggerFn { + return func(format string, params ...any) { + log.Printf("[%s] %s", name, fmt.Sprintf(format, params...)) + } +} + +func NewInMemoryQueueDriver(name string) *InMemoryQueueDriver { + return &InMemoryQueueDriver{ + log: newLoggerFn(fmt.Sprintf("InMemoryQueueDriver - %s", name)), + queue: make(chan Job), + exit: make(chan bool), + name: name, + } +} + +type InMemoryQueueDriver struct { + log loggerFn + queue chan Job + exit chan bool + q *Queue + name string +} + +func (qd *InMemoryQueueDriver) SetQueue(q *Queue) { + qd.q = q +} + +func (qd InMemoryQueueDriver) Enqueue(job Job) { + qd.queue <- job +} + +func (qd InMemoryQueueDriver) Start() { + for i := 0; i < QueueWorkerCount; i++ { + go func() { + qd.log("start") + for { + select { + case <-qd.exit: + qd.log("exit") + return + case job := <-qd.queue: + qd.q.Listen(job) + } + } + }() + } +} + +func (qd InMemoryQueueDriver) Stop() { + qd.log("stopping") + close(qd.exit) +} diff --git a/server/executor/queue_driver_postgres.go b/server/executor/queue_driver_postgres.go new file mode 100644 index 0000000000..d3ff8e97f1 --- /dev/null +++ b/server/executor/queue_driver_postgres.go @@ -0,0 +1,185 @@ +package executor + +import ( + 
"context" + "encoding/json" + "fmt" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/kubeshop/tracetest/server/pkg/id" +) + +func NewPostgresQueueDriver(pool *pgxpool.Pool) *PostgresQueueDriver { + id := id.GenerateID() + return &PostgresQueueDriver{ + log: newLoggerFn("PostgresQueueDriver - " + id.String()), + pool: pool, + channels: map[string]*channel{}, + exit: make(chan bool), + } +} + +// PostgresQueueDriver is a queue driver that uses Postgres LISTEN/NOTIFY +// Since each queue needs its own connection, it's not practical/scalable +// to create a new Driver instance for each queue. Instead, we create a +// single Driver instance and use it to create channels for each queue. +// +// This driver requires one connection that listens to messages in any queue +// and routes them to the correct worker. +type PostgresQueueDriver struct { + log loggerFn + pool *pgxpool.Pool + channels map[string]*channel + running bool + exit chan bool +} + +func (qd *PostgresQueueDriver) getChannel(name string) (*channel, error) { + ch, ok := qd.channels[name] + if !ok { + return nil, fmt.Errorf("channel %s not found", name) + } + + return ch, nil +} + +const pgChannelName = "tracetest_queue" + +type pgJob struct { + Channel string `json:"channel"` + Job Job `json:"job"` +} + +func (qd *PostgresQueueDriver) Start() { + if qd.running { + // we want only 1 worker here + qd.log("already running") + return + } + qd.running = true + + go func(qd *PostgresQueueDriver) { + qd.log("start") + + qd.log("acquiring connection") + conn, err := qd.pool.Acquire(context.Background()) + if err != nil { + panic(fmt.Errorf("error acquiring connection: %w", err)) + } + defer conn.Release() + + for { + select { + case <-qd.exit: + qd.log("exit") + return + default: + qd.worker(conn) + } + } + }(qd) +} + +func (qd *PostgresQueueDriver) worker(conn *pgxpool.Conn) { + qd.log("listening for notifications") + _, err := conn.Exec(context.Background(), "listen "+pgChannelName) + if err != nil { + 
qd.log("error listening for notifications: %s", err.Error()) + return + } + qd.log("waiting for notification") + notification, err := conn.Conn().WaitForNotification(context.Background()) + if err != nil { + qd.log("error waiting for notification: %s", err.Error()) + return + } + + job := pgJob{} + err = json.Unmarshal([]byte(notification.Payload), &job) + if err != nil { + qd.log("error unmarshalling pgJob: %s", err.Error()) + return + } + + qd.log("received job for channel: %s, runID: %d", job.Channel, job.Job.Run.ID) + + channel, err := qd.getChannel(job.Channel) + if err != nil { + qd.log("error getting channel: %s", err.Error()) + return + } + + // spin off so we can keep listening for jobs + go channel.q.Listen(job.Job) + qd.log("spun off job for channel: %s, runID: %d", job.Channel, job.Job.Run.ID) +} + +func (qd *PostgresQueueDriver) Stop() { + qd.log("stopping") + qd.exit <- true +} + +// Channel registers a new queue channel and returns it +func (qd *PostgresQueueDriver) Channel(name string) *channel { + if _, channelNameExists := qd.channels[name]; channelNameExists { + panic(fmt.Errorf("channel %s already exists", name)) + } + + ch := &channel{ + PostgresQueueDriver: qd, + name: name, + log: newLoggerFn(fmt.Sprintf("PostgresQueueDriver - %s", name)), + pool: qd.pool, + } + + qd.channels[name] = ch + + return ch +} + +type channel struct { + *PostgresQueueDriver + name string + log loggerFn + pool *pgxpool.Pool + q *Queue +} + +func (ch *channel) SetQueue(q *Queue) { + ch.q = q +} + +const enqueueTimeout = 500 * time.Millisecond + +func (ch *channel) Enqueue(job Job) { + ch.log("enqueue") + + jj, err := json.Marshal(pgJob{ + Channel: ch.name, + Job: job, + }) + + if err != nil { + ch.log("error marshalling pgJob: %s", err.Error()) + return + } + + ctx, cancelCtx := context.WithTimeout(context.Background(), enqueueTimeout) + defer cancelCtx() + + conn, err := ch.pool.Acquire(ctx) + if err != nil { + ch.log("error acquiring connection: 
%s", err.Error()) + return + } + defer conn.Release() + + _, err = conn.Exec(ctx, fmt.Sprintf(`select pg_notify('%s', $1)`, pgChannelName), jj) + if err != nil { + ch.log("error notifying postgres: %s", err.Error()) + return + } + + ch.log("notified postgres") +} diff --git a/server/executor/runner_test.go b/server/executor/runner_test.go index f16b5d0bbf..e1e8a032c2 100644 --- a/server/executor/runner_test.go +++ b/server/executor/runner_test.go @@ -316,6 +316,10 @@ func (m *runsRepoMock) UpdateRun(_ context.Context, run test.Run) error { } func (r *runsRepoMock) GetRun(_ context.Context, testID id.ID, runID int) (test.Run, error) { + if run, ok := r.runs[testID]; ok && run.ID == runID { + return run, nil + } + args := r.Called(testID, runID) return args.Get(0).(test.Run), args.Error(1) } diff --git a/server/go.mod b/server/go.mod index b30ff4bb22..389c80f3f6 100644 --- a/server/go.mod +++ b/server/go.mod @@ -92,6 +92,7 @@ require ( github.com/inconshreveable/mousetrap v1.0.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/compress v1.15.6 // indirect github.com/knadh/koanf v1.4.0 // indirect diff --git a/server/go.sum b/server/go.sum index 1036029792..9904cae5dc 100644 --- a/server/go.sum +++ b/server/go.sum @@ -1069,6 +1069,8 @@ github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0f github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle/v2 v2.2.0 h1:RdcDk92EJBuBS55nQMMYFXTxwstHug4jkhT5pq8VxPk= +github.com/jackc/puddle/v2 v2.2.0/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= diff --git a/server/integration/ensure_server_prefix_test.go b/server/integration/ensure_server_prefix_test.go index b17fd5da23..c0bc9aac51 100644 --- a/server/integration/ensure_server_prefix_test.go +++ b/server/integration/ensure_server_prefix_test.go @@ -3,7 +3,7 @@ package integration_test import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "sync" "testing" @@ -61,7 +61,7 @@ func getTests(t *testing.T, endpoint string) resourcemanager.ResourceList[test.T require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) - bodyJsonBytes, err := ioutil.ReadAll(resp.Body) + bodyJsonBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) var tests resourcemanager.ResourceList[test.Test] @@ -77,7 +77,7 @@ func getDatastores(t *testing.T, endpoint string) resourcemanager.ResourceList[d require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) - bodyJsonBytes, err := ioutil.ReadAll(resp.Body) + bodyJsonBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) var dataStores resourcemanager.ResourceList[datastore.DataStore]