Skip to content

Commit

Permalink
[+] add /startchain and /stopchain REST API endpoints, closes #482
Browse files Browse the repository at this point in the history
  • Loading branch information
pashagolub committed Oct 13, 2022
1 parent b22f2bd commit 4f8efd9
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 53 deletions.
16 changes: 11 additions & 5 deletions internal/api/status.go → internal/api/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"fmt"
"net/http"
"time"
Expand All @@ -9,14 +10,16 @@ import (
"github.com/cybertec-postgresql/pg_timetable/internal/log"
)

// StatusReporter is a common interface describing the current status of a connection
type StatusReporter interface {
// RestHandler is a common interface describing the current status of a connection
type RestHandler interface {
IsReady() bool
StartChain(context.Context, int) error
StopChain(context.Context, int) error
}

type RestApiServer struct {
Reporter StatusReporter
l log.LoggerIface
ApiHandler RestHandler
l log.LoggerIface
http.Server
}

Expand All @@ -35,6 +38,8 @@ func Init(opts config.RestApiOpts, logger log.LoggerIface) *RestApiServer {
w.WriteHeader(http.StatusOK) // i'm serving hence I'm alive
})
http.HandleFunc("/readiness", s.readinessHandler)
http.HandleFunc("/startchain", s.chainHandler)
http.HandleFunc("/stopchain", s.chainHandler)
if opts.Port != 0 {
logger.WithField("port", opts.Port).Info("Starting REST API server...")
go func() { logger.Error(s.ListenAndServe()) }()
Expand All @@ -44,9 +49,10 @@ func Init(opts config.RestApiOpts, logger log.LoggerIface) *RestApiServer {

func (Server *RestApiServer) readinessHandler(w http.ResponseWriter, r *http.Request) {
Server.l.Debug("Received /readiness REST API request")
if Server.Reporter == nil || !Server.Reporter.IsReady() {
if Server.ApiHandler == nil || !Server.ApiHandler.IsReady() {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
r.Context()
}
77 changes: 77 additions & 0 deletions internal/api/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package api_test

import (
"context"
"errors"
"io"
"net/http"
"testing"

"github.com/cybertec-postgresql/pg_timetable/internal/api"
"github.com/cybertec-postgresql/pg_timetable/internal/config"
"github.com/cybertec-postgresql/pg_timetable/internal/log"
"github.com/stretchr/testify/assert"
)

type apihandler struct {
}

func (r *apihandler) IsReady() bool {
return true
}

func (sch *apihandler) StartChain(ctx context.Context, chainId int) error {
if chainId == 0 {
return errors.New("invalid chain id")
}
return nil
}

func (sch *apihandler) StopChain(ctx context.Context, chainId int) error {
return nil
}

var restsrv *api.RestApiServer

func init() {
restsrv = api.Init(config.RestApiOpts{Port: 8080}, log.Init(config.LoggingOpts{LogLevel: "error"}))
}

func TestStatus(t *testing.T) {

r, err := http.Get("http://localhost:8080/liveness")
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, r.StatusCode)

r, err = http.Get("http://localhost:8080/readiness")
assert.NoError(t, err)
assert.Equal(t, http.StatusServiceUnavailable, r.StatusCode)

restsrv.ApiHandler = &apihandler{}
r, err = http.Get("http://localhost:8080/readiness")
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, r.StatusCode)
}

func TestChainManager(t *testing.T) {
restsrv.ApiHandler = &apihandler{}
r, err := http.Get("http://localhost:8080/startchain")
assert.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, r.StatusCode)
b, _ := io.ReadAll(r.Body)
assert.Contains(t, string(b), "invalid syntax")

r, err = http.Get("http://localhost:8080/startchain?id=1")
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, r.StatusCode)

r, err = http.Get("http://localhost:8080/stopchain?id=1")
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, r.StatusCode)

r, err = http.Get("http://localhost:8080/startchain?id=0")
assert.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, r.StatusCode)
b, _ = io.ReadAll(r.Body)
assert.Contains(t, string(b), "invalid chain id")
}
26 changes: 26 additions & 0 deletions internal/api/chainapi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package api

import (
"net/http"
"strconv"
)

func (Server *RestApiServer) chainHandler(w http.ResponseWriter, r *http.Request) {
Server.l.Debugf("Received /%s REST API request", r.URL.Path)
chainID, err := strconv.Atoi(r.URL.Query().Get("id"))
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch r.URL.Path {
case "/startchain":
err = Server.ApiHandler.StartChain(r.Context(), chainID)
case "/stopchain":
err = Server.ApiHandler.StopChain(r.Context(), chainID)
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
}
34 changes: 0 additions & 34 deletions internal/api/status_test.go

This file was deleted.

38 changes: 25 additions & 13 deletions internal/scheduler/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package scheduler

import (
"context"
"fmt"
"strings"
"time"

Expand All @@ -10,7 +11,10 @@ import (
pgx "github.com/jackc/pgx/v5"
)

type Chain = pgengine.Chain
type (
Chain = pgengine.Chain
ChainSignal = pgengine.ChainSignal
)

// SendChain sends chain to the channel for workers
func (sch *Scheduler) SendChain(c Chain) {
Expand Down Expand Up @@ -46,21 +50,29 @@ func (sch *Scheduler) retrieveAsyncChainsAndRun(ctx context.Context) {
if chainSignal.ConfigID == 0 {
return
}
err := sch.processAsyncChain(ctx, chainSignal)
if err != nil {
sch.l.WithError(err).Error("Could not process async chain command")
}
}
}

func (sch *Scheduler) processAsyncChain(ctx context.Context, chainSignal ChainSignal) error {
switch chainSignal.Command {
case "START":
var c Chain
switch chainSignal.Command {
case "START":
err := sch.pgengine.SelectChain(ctx, &c, chainSignal.ConfigID)
if err != nil {
sch.l.WithError(err).Error("Could not query pending tasks")
} else {
sch.SendChain(c)
}
case "STOP":
if cancel, ok := sch.activeChains[chainSignal.ConfigID]; ok {
cancel()
}
if err := sch.pgengine.SelectChain(ctx, &c, chainSignal.ConfigID); err != nil {
return err
}
sch.SendChain(c)
case "STOP":
if cancel, ok := sch.activeChains[chainSignal.ConfigID]; ok {
cancel()
return nil
}
return fmt.Errorf("Cannot stop chain with ID: %d. No running chain found", chainSignal.ConfigID)
}
return nil
}

func (sch *Scheduler) retrieveChainsAndRun(ctx context.Context, reboot bool) {
Expand Down
14 changes: 14 additions & 0 deletions internal/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ func (sch *Scheduler) IsReady() bool {
return sch.status == RunningStatus
}

func (sch *Scheduler) StartChain(ctx context.Context, chainId int) error {
return sch.processAsyncChain(ctx, ChainSignal{
ConfigID: chainId,
Command: "START",
Ts: time.Now().Unix()})
}

func (sch *Scheduler) StopChain(ctx context.Context, chainId int) error {
return sch.processAsyncChain(ctx, ChainSignal{
ConfigID: chainId,
Command: "STOP",
Ts: time.Now().Unix()})
}

// Run executes jobs. Returns RunStatus why it terminated.
// There are only two possibilities: dropped connection and cancelled context.
func (sch *Scheduler) Run(ctx context.Context) RunStatus {
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func main() {
return
}
sch := scheduler.New(pge, logger)
apiserver.Reporter = sch
apiserver.ApiHandler = sch

if sch.Run(ctx) == scheduler.ShutdownStatus {
exitCode = ExitCodeShutdownCommand
Expand Down

0 comments on commit 4f8efd9

Please sign in to comment.