diff --git a/internal/api/status.go b/internal/api/api.go similarity index 72% rename from internal/api/status.go rename to internal/api/api.go index 3cfc940a..1119619c 100644 --- a/internal/api/status.go +++ b/internal/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "fmt" "net/http" "time" @@ -9,14 +10,16 @@ import ( "github.com/cybertec-postgresql/pg_timetable/internal/log" ) -// StatusReporter is a common interface describing the current status of a connection -type StatusReporter interface { +// RestHandler is a common interface describing the current status of a connection +type RestHandler interface { IsReady() bool + StartChain(context.Context, int) error + StopChain(context.Context, int) error } type RestApiServer struct { - Reporter StatusReporter - l log.LoggerIface + ApiHandler RestHandler + l log.LoggerIface http.Server } @@ -35,6 +38,8 @@ func Init(opts config.RestApiOpts, logger log.LoggerIface) *RestApiServer { w.WriteHeader(http.StatusOK) // i'm serving hence I'm alive }) http.HandleFunc("/readiness", s.readinessHandler) + http.HandleFunc("/startchain", s.chainHandler) + http.HandleFunc("/stopchain", s.chainHandler) if opts.Port != 0 { logger.WithField("port", opts.Port).Info("Starting REST API server...") go func() { logger.Error(s.ListenAndServe()) }() @@ -44,9 +49,10 @@ func Init(opts config.RestApiOpts, logger log.LoggerIface) *RestApiServer { func (Server *RestApiServer) readinessHandler(w http.ResponseWriter, r *http.Request) { Server.l.Debug("Received /readiness REST API request") - if Server.Reporter == nil || !Server.Reporter.IsReady() { + if Server.ApiHandler == nil || !Server.ApiHandler.IsReady() { w.WriteHeader(http.StatusServiceUnavailable) return } w.WriteHeader(http.StatusOK) + r.Context() } diff --git a/internal/api/api_test.go b/internal/api/api_test.go new file mode 100644 index 00000000..f5cdf100 --- /dev/null +++ b/internal/api/api_test.go @@ -0,0 +1,77 @@ +package api_test + +import ( + "context" + "errors" + "io" + "net/http" + "testing" + + "github.com/cybertec-postgresql/pg_timetable/internal/api" + "github.com/cybertec-postgresql/pg_timetable/internal/config" + "github.com/cybertec-postgresql/pg_timetable/internal/log" + "github.com/stretchr/testify/assert" +) + +type apihandler struct { +} + +func (r *apihandler) IsReady() bool { + return true +} + +func (sch *apihandler) StartChain(ctx context.Context, chainId int) error { + if chainId == 0 { + return errors.New("invalid chain id") + } + return nil +} + +func (sch *apihandler) StopChain(ctx context.Context, chainId int) error { + return nil +} + +var restsrv *api.RestApiServer + +func init() { + restsrv = api.Init(config.RestApiOpts{Port: 8080}, log.Init(config.LoggingOpts{LogLevel: "error"})) +} + +func TestStatus(t *testing.T) { + + r, err := http.Get("http://localhost:8080/liveness") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, r.StatusCode) + + r, err = http.Get("http://localhost:8080/readiness") + assert.NoError(t, err) + assert.Equal(t, http.StatusServiceUnavailable, r.StatusCode) + + restsrv.ApiHandler = &apihandler{} + r, err = http.Get("http://localhost:8080/readiness") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, r.StatusCode) +} + +func TestChainManager(t *testing.T) { + restsrv.ApiHandler = &apihandler{} + r, err := http.Get("http://localhost:8080/startchain") + assert.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, r.StatusCode) + b, _ := io.ReadAll(r.Body) + assert.Contains(t, string(b), "invalid syntax") + + r, err = http.Get("http://localhost:8080/startchain?id=1") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, r.StatusCode) + + r, err = http.Get("http://localhost:8080/stopchain?id=1") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, r.StatusCode) + + r, err = http.Get("http://localhost:8080/startchain?id=0") + assert.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, r.StatusCode) + b, _ = io.ReadAll(r.Body) + assert.Contains(t, string(b), "invalid chain id") +} diff --git a/internal/api/chainapi.go b/internal/api/chainapi.go new file mode 100644 index 00000000..df6a8195 --- /dev/null +++ b/internal/api/chainapi.go @@ -0,0 +1,26 @@ +package api + +import ( + "net/http" + "strconv" +) + +func (Server *RestApiServer) chainHandler(w http.ResponseWriter, r *http.Request) { + Server.l.Debug("Received chain REST API request") + chainID, err := strconv.Atoi(r.URL.Query().Get("id")) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + switch r.URL.Path { + case "/startchain": + err = Server.ApiHandler.StartChain(r.Context(), chainID) + case "/stopchain": + err = Server.ApiHandler.StopChain(r.Context(), chainID) + } + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) +} diff --git a/internal/api/status_test.go b/internal/api/status_test.go deleted file mode 100644 index 312c3bbc..00000000 --- a/internal/api/status_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package api_test - -import ( - "net/http" - "testing" - - "github.com/cybertec-postgresql/pg_timetable/internal/api" - "github.com/cybertec-postgresql/pg_timetable/internal/config" - "github.com/cybertec-postgresql/pg_timetable/internal/log" - "github.com/stretchr/testify/assert" -) - -type reporter struct { -} - -func (r *reporter) IsReady() bool { - return true -} - -func TestStatus(t *testing.T) { - restsrv := api.Init(config.RestApiOpts{Port: 8080}, log.Init(config.LoggingOpts{LogLevel: "error"})) - r, err := http.Get("http://localhost:8080/liveness") - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, r.StatusCode) - - r, err = http.Get("http://localhost:8080/readiness") - assert.NoError(t, err) - assert.Equal(t, http.StatusServiceUnavailable, r.StatusCode) - - restsrv.Reporter = &reporter{} - r, err = http.Get("http://localhost:8080/readiness") - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, r.StatusCode) -} diff --git a/internal/pgengine/access.go b/internal/pgengine/access.go index 54ec5b30..00bfd104 100644 --- a/internal/pgengine/access.go +++ b/internal/pgengine/access.go @@ -107,9 +107,14 @@ FROM timetable.chain WHERE live AND (client_name = $1 or client_name IS NULL) AN } // SelectChain returns the chain with the specified ID -func (pge *PgEngine) SelectChain(ctx context.Context, dest interface{}, chainID int) error { +func (pge *PgEngine) SelectChain(ctx context.Context, dest *Chain, chainID int) error { // we accept not only live chains here because we want to run them in debug mode const sqlSelectSingleChain = `SELECT chain_id, chain_name, self_destruct, exclusive_execution, COALESCE(timeout, 0) as timeout, COALESCE(max_instances, 16) as max_instances FROM timetable.chain WHERE (client_name = $1 OR client_name IS NULL) AND chain_id = $2` - return pge.ConfigDb.QueryRow(ctx, sqlSelectSingleChain, pge.ClientName, chainID).Scan(dest) + rows, err := pge.ConfigDb.Query(ctx, sqlSelectSingleChain, pge.ClientName, chainID) + if err != nil { + return err + } + *dest, err = pgx.CollectOneRow(rows, RowToStructByName[Chain]) + return err } diff --git a/internal/pgengine/access_test.go b/internal/pgengine/access_test.go index b341b629..c68538fa 100644 --- a/internal/pgengine/access_test.go +++ b/internal/pgengine/access_test.go @@ -83,7 +83,7 @@ func TestSelectChain(t *testing.T) { defer mockPool.Close() mockPool.ExpectExec("SELECT.+chain_id").WillReturnError(errors.New("error")) - assert.Error(t, pge.SelectChain(context.Background(), struct{}{}, 42)) + assert.Error(t, pge.SelectChain(context.Background(), &pgengine.Chain{}, 42)) } func TestIsAlive(t *testing.T) { diff --git a/internal/scheduler/chain.go b/internal/scheduler/chain.go index d5415925..3b0e85bd 100644 --- a/internal/scheduler/chain.go +++ b/internal/scheduler/chain.go @@ -2,6 +2,7 @@ package scheduler import ( "context" + "fmt" "strings" "time" @@ -10,7 +11,10 @@ import ( pgx "github.com/jackc/pgx/v5" ) -type Chain = pgengine.Chain +type ( + Chain = pgengine.Chain + ChainSignal = pgengine.ChainSignal +) // SendChain sends chain to the channel for workers func (sch *Scheduler) SendChain(c Chain) { @@ -46,21 +50,29 @@ func (sch *Scheduler) retrieveAsyncChainsAndRun(ctx context.Context) { if chainSignal.ConfigID == 0 { return } + err := sch.processAsyncChain(ctx, chainSignal) + if err != nil { + sch.l.WithError(err).Error("Could not process async chain command") + } + } +} + +func (sch *Scheduler) processAsyncChain(ctx context.Context, chainSignal ChainSignal) error { + switch chainSignal.Command { + case "START": var c Chain - switch chainSignal.Command { - case "START": - err := sch.pgengine.SelectChain(ctx, &c, chainSignal.ConfigID) - if err != nil { - sch.l.WithError(err).Error("Could not query pending tasks") - } else { - sch.SendChain(c) - } - case "STOP": - if cancel, ok := sch.activeChains[chainSignal.ConfigID]; ok { - cancel() - } + if err := sch.pgengine.SelectChain(ctx, &c, chainSignal.ConfigID); err != nil { + return err + } + sch.SendChain(c) + case "STOP": + if cancel, ok := sch.activeChains[chainSignal.ConfigID]; ok { + cancel() + return nil } + return fmt.Errorf("Cannot stop chain with ID: %d. No running chain found", chainSignal.ConfigID) } + return nil } func (sch *Scheduler) retrieveChainsAndRun(ctx context.Context, reboot bool) { diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 97881c22..af9b51ec 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -84,6 +84,20 @@ func (sch *Scheduler) IsReady() bool { return sch.status == RunningStatus } +func (sch *Scheduler) StartChain(ctx context.Context, chainId int) error { + return sch.processAsyncChain(ctx, ChainSignal{ + ConfigID: chainId, + Command: "START", + Ts: time.Now().Unix()}) +} + +func (sch *Scheduler) StopChain(ctx context.Context, chainId int) error { + return sch.processAsyncChain(ctx, ChainSignal{ + ConfigID: chainId, + Command: "STOP", + Ts: time.Now().Unix()}) +} + // Run executes jobs. Returns RunStatus why it terminated. // There are only two possibilities: dropped connection and cancelled context. func (sch *Scheduler) Run(ctx context.Context) RunStatus { diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go index 75105f0e..ed77c446 100644 --- a/internal/scheduler/scheduler_test.go +++ b/internal/scheduler/scheduler_test.go @@ -15,7 +15,7 @@ import ( var pge *pgengine.PgEngine -//SetupTestCase used to connect and to initialize test PostgreSQL database +// SetupTestCase used to connect and to initialize test PostgreSQL database func SetupTestCase(t *testing.T) func(t *testing.T) { cmdOpts := config.NewCmdOptions("-c", "pgengine_unit_test", "--password=somestrong") t.Log("Setup test case") @@ -53,6 +53,8 @@ func TestRun(t *testing.T) { err = pge.ExecuteCustomScripts(context.Background(), "../../samples/ManyTasks.sql") assert.NoError(t, err, "Creating many tasks failed") sch := New(pge, log.Init(config.LoggingOpts{LogLevel: "error"})) + assert.NoError(t, sch.StartChain(context.Background(), 1)) + assert.ErrorContains(t, sch.StopChain(context.Background(), -1), "No running chain found") go func() { time.Sleep(10 * time.Second) sch.Shutdown() diff --git a/main.go b/main.go index c70aebaf..427927f3 100644 --- a/main.go +++ b/main.go @@ -118,7 +118,7 @@ func main() { return } sch := scheduler.New(pge, logger) - apiserver.Reporter = sch + apiserver.ApiHandler = sch if sch.Run(ctx) == scheduler.ShutdownStatus { exitCode = ExitCodeShutdownCommand