From eec83abf84734918403465e44ba551cf3f0262c4 Mon Sep 17 00:00:00 2001 From: Henri Koski Date: Sun, 17 Mar 2024 20:35:11 +0200 Subject: [PATCH] Add broadcaster logic (#6) * Add broadcaster logic * Update db schema * Update tests * Fix tests --- api/api.go | 55 +++------- api/broadcaster.go | 145 +++++++++++++++++++++++++++ api/server.go | 4 +- api/templates/index.templ | 11 +- api/templates/index_templ.go | 43 ++++++-- app/app.go | 5 +- go.mod | 2 + go.sum | 5 + integrationtests/integration_test.go | 106 ++++++++++++++++---- integrationtests/utils_test.go | 3 +- signalhandler/signal_handler.go | 6 +- store/mock/store.go | 3 +- store/pg/store.go | 58 ++++++++--- testdata/tables.sql | 3 +- 14 files changed, 354 insertions(+), 95 deletions(-) create mode 100644 api/broadcaster.go diff --git a/api/api.go b/api/api.go index 15e9a7a..c5c42c0 100644 --- a/api/api.go +++ b/api/api.go @@ -1,9 +1,7 @@ package api import ( - "bytes" "embed" - "encoding/json" "io/fs" "log/slog" "net/http" @@ -24,12 +22,14 @@ type Store interface { type API struct { r *http.ServeMux s Store + b *Broadcaster } -func New(s Store) (*API, error) { +func New(s Store, b *Broadcaster) (*API, error) { a := &API{ r: http.NewServeMux(), s: s, + b: b, } assetsFS, err := fs.Sub(assets, "assets") @@ -38,8 +38,7 @@ func New(s Store) (*API, error) { } a.r.Handle("GET /", templ.Handler(templates.Index())) - a.r.HandleFunc("GET /api/v1/stats", a.stats) - a.r.HandleFunc("GET /api/v1/events", a.events) + a.r.HandleFunc("GET /api/v1/stream", a.stream) a.r.Handle("GET /assets/", http.StripPrefix("/assets/", http.FileServer(http.FS(assetsFS)))) return a, nil @@ -49,56 +48,28 @@ func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { a.r.ServeHTTP(w, r) } -func (a *API) stats(w http.ResponseWriter, r *http.Request) { - stats, err := a.s.Stats() - if err != nil { - slog.Error("failed to get stats", slog.Any("error", err)) - http.Error(w, "failed to get stats", http.StatusInternalServerError) - return - } - - if r.Header.Get("Accept") == "application/json" { - if err := json.NewEncoder(w).Encode(stats); err != nil { - slog.Error("failed encode stats", slog.Any("error", err)) - return - } - return - } - - if err := templates.Stats(stats).Render(r.Context(), w); err != nil { - slog.Error("failed render stats", slog.Any("error", err)) - return - } -} - -func (a *API) events(w http.ResponseWriter, r *http.Request) { +func (a *API) stream(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") slog.Info("client connected", slog.String("remote_addr", r.RemoteAddr)) - + ch, unsubscribe := a.b.subscribe() + defer unsubscribe() for { select { case <-r.Context().Done(): // Client disconnected slog.Info("client disconnected", slog.String("remote_addr", r.RemoteAddr)) return - case event, ok := <-a.s.Events(): - if !ok { - return - } - - buf := &bytes.Buffer{} - buf.WriteString("data: ") - if err := templates.Row(event).Render(r.Context(), buf); err != nil { - slog.Error("failed render row", slog.Any("error", err)) + case data := <-ch: + if _, err := w.Write(data); err != nil { + slog.Error("failed to write to client, closing connection", slog.String("remote_addr", r.RemoteAddr), slog.Any("err", err)) return } - buf.WriteString("\n\n") - if _, err := buf.WriteTo(w); err != nil { - slog.Error("failed send event", slog.Any("error", err)) - } w.(http.Flusher).Flush() + case <-a.b.done: + slog.Info("broadcaster stopped, closing connection", slog.String("remote_addr", r.RemoteAddr)) + return } } } diff --git a/api/broadcaster.go b/api/broadcaster.go new file mode 100644 index 0000000..992eae2 --- /dev/null +++ b/api/broadcaster.go @@ -0,0 +1,145 @@ +package api + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "sync" + "time" + + "github.com/a-h/templ" + "github.com/gevulotnetwork/devnet-explorer/api/templates" + "github.com/gevulotnetwork/devnet-explorer/model" +) + +type Broadcaster struct { + s Store + clientsMu sync.Mutex + nextID uint64 + clients map[uint64]chan<- []byte + headIndex uint8 + head [50][]byte + + done chan struct{} +} + +func NewBroadcaster(s Store) *Broadcaster { + return &Broadcaster{ + s: s, + clients: make(map[uint64]chan<- []byte), + done: make(chan struct{}), + } +} + +func (b *Broadcaster) subscribe() (data <-chan []byte, unsubscribe func()) { + b.clientsMu.Lock() + defer b.clientsMu.Unlock() + + id := b.nextID + ch := make(chan []byte, len(b.head)+2) + b.clients[id] = ch + b.nextID++ + slog.Info("client subscribed", slog.Uint64("id", id)) + + for i := 1; i <= len(b.head); i++ { + idx := (b.headIndex + uint8(i)) % uint8(len(b.head)) + if b.head[idx] != nil { + ch <- b.head[idx] + } + } + + return ch, func() { + slog.Info("client unsubscribed", slog.Uint64("id", id)) + b.clientsMu.Lock() + defer b.clientsMu.Unlock() + delete(b.clients, id) + close(ch) + } +} + +func (b *Broadcaster) Run() error { + t := time.NewTicker(time.Second * 2) + for { + var ev EventComponent + select { + case event, ok := <-b.s.Events(): + if !ok { + slog.Info("store.Events() channel closed, broadcasting stopped") + return nil + } + slog.Debug("new tx event received") + ev = TXRowEvent(event) + case <-t.C: + stats, err := b.s.Stats() + if err != nil { + return fmt.Errorf("failed to get stats: %w", err) + } + slog.Debug("stats updated") + ev = StatEvent(stats) + case <-b.done: + return nil + } + + buf := &bytes.Buffer{} + if err := writeEvent(buf, ev); err != nil { + slog.Error("failed write event into buffer", slog.Any("error", err)) + continue + } + + b.broadcast(buf.Bytes()) + } +} + +func (b *Broadcaster) broadcast(data []byte) { + b.clientsMu.Lock() + defer b.clientsMu.Unlock() + b.head[b.headIndex] = data + b.headIndex = (b.headIndex + 1) % uint8(len(b.head)) + for id, c := range b.clients { + select { + case c <- data: + slog.Debug("data broadcasted", slog.Uint64("id", id)) + default: + slog.Info("client blocked, broadcasting event skipped", slog.Uint64("id", id)) + } + } +} + +func (b *Broadcaster) Stop() error { + close(b.done) + return nil +} + +type EventComponent struct { + templ.Component + name string +} + +func (e EventComponent) Name() string { + return e.name +} + +func TXRowEvent(e model.Event) EventComponent { + return EventComponent{ + Component: templates.Row(e), + name: templates.EventTXRow, + } +} + +func StatEvent(s model.Stats) EventComponent { + return EventComponent{ + Component: templates.Stats(s), + name: templates.EventStats, + } +} + +func writeEvent(w io.Writer, c EventComponent) error { + fmt.Fprintf(w, "event: %s\ndata: ", c.Name()) + if err := c.Render(context.Background(), w); err != nil { + return fmt.Errorf("failed render html: %w", err) + } + fmt.Fprint(w, "\n\n") + return nil +} diff --git a/api/server.go b/api/server.go index d3112fa..f2f5225 100644 --- a/api/server.go +++ b/api/server.go @@ -12,8 +12,8 @@ type Server struct { srv *http.Server } -func NewServer(addr string, s Store) (*Server, error) { - a, err := New(s) +func NewServer(addr string, s Store, b *Broadcaster) (*Server, error) { + a, err := New(s, b) if err != nil { return nil, fmt.Errorf("failed to create api: %w", err) } diff --git a/api/templates/index.templ b/api/templates/index.templ index bafceb5..43732bb 100644 --- a/api/templates/index.templ +++ b/api/templates/index.templ @@ -3,12 +3,17 @@ package templates import "github.com/gevulotnetwork/devnet-explorer/model" import "strconv" +const ( + EventTXRow = "tx-row" + EventStats = "stats" +) + templ Index() { @head() -
+
@header() @Stats(model.Stats{}) @Table(nil) @@ -19,7 +24,7 @@ templ Index() { } templ Stats(stats model.Stats) { -
+
{ strconv.Itoa(int(stats.RegisteredUsers)) }
Registered
Users
@@ -52,7 +57,7 @@ templ Table(events []model.Event) {
-
+
for _, e := range events { @Row(e) } diff --git a/api/templates/index_templ.go b/api/templates/index_templ.go index 17003de..832711c 100644 --- a/api/templates/index_templ.go +++ b/api/templates/index_templ.go @@ -13,6 +13,11 @@ import "bytes" import "github.com/gevulotnetwork/devnet-explorer/model" import "strconv" +const ( + EventTXRow = "tx-row" + EventStats = "stats" +) + func Index() templ.Component { return templ.ComponentFunc(func(ctx context.Context, templ_7745c5c3_W io.Writer) (templ_7745c5c3_Err error) { templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templ_7745c5c3_W.(*bytes.Buffer) @@ -34,7 +39,7 @@ func Index() templ.Component { if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } - _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("
") + _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("
") if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } @@ -78,14 +83,22 @@ func Stats(stats model.Stats) templ.Component { templ_7745c5c3_Var2 = templ.NopComponent } ctx = templ.ClearChildren(ctx) - _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("
") + _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("
") if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } var templ_7745c5c3_Var3 string templ_7745c5c3_Var3, templ_7745c5c3_Err = templ.JoinStringErrs(strconv.Itoa(int(stats.RegisteredUsers))) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 23, Col: 95} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 28, Col: 95} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var3)) if templ_7745c5c3_Err != nil { @@ -98,7 +111,7 @@ func Stats(stats model.Stats) templ.Component { var templ_7745c5c3_Var4 string templ_7745c5c3_Var4, templ_7745c5c3_Err = templ.JoinStringErrs(strconv.Itoa(int(stats.ProversDeployed))) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 27, Col: 95} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 32, Col: 95} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var4)) if templ_7745c5c3_Err != nil { @@ -111,7 +124,7 @@ func Stats(stats model.Stats) templ.Component { var templ_7745c5c3_Var5 string templ_7745c5c3_Var5, templ_7745c5c3_Err = templ.JoinStringErrs(strconv.Itoa(int(stats.ProofsGenerated))) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 31, Col: 95} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 36, Col: 95} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var5)) if templ_7745c5c3_Err != nil { @@ -124,7 +137,7 @@ func Stats(stats model.Stats) templ.Component { var templ_7745c5c3_Var6 string templ_7745c5c3_Var6, templ_7745c5c3_Err = templ.JoinStringErrs(strconv.Itoa(int(stats.ProofsVerified))) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 35, Col: 93} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 40, Col: 93} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var6)) if templ_7745c5c3_Err != nil { @@ -154,7 +167,15 @@ func Table(events []model.Event) templ.Component { templ_7745c5c3_Var7 = templ.NopComponent } ctx = templ.ClearChildren(ctx) - _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("
State
Transaction ID
Prover ID
Time
") + _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString("
State
Transaction ID
Prover ID
Time
") if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } @@ -212,7 +233,7 @@ func Row(e model.Event) templ.Component { var templ_7745c5c3_Var10 string templ_7745c5c3_Var10, templ_7745c5c3_Err = templ.JoinStringErrs(e.State) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 63, Col: 91} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 68, Col: 91} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var10)) if templ_7745c5c3_Err != nil { @@ -225,7 +246,7 @@ func Row(e model.Event) templ.Component { var templ_7745c5c3_Var11 string templ_7745c5c3_Var11, templ_7745c5c3_Err = templ.JoinStringErrs(e.TxID) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 63, Col: 130} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 68, Col: 130} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var11)) if templ_7745c5c3_Err != nil { @@ -238,7 +259,7 @@ func Row(e model.Event) templ.Component { var templ_7745c5c3_Var12 string templ_7745c5c3_Var12, templ_7745c5c3_Err = templ.JoinStringErrs(e.ProverID) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 63, Col: 191} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 68, Col: 191} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var12)) if templ_7745c5c3_Err != nil { @@ -251,7 +272,7 @@ func Row(e model.Event) templ.Component { var templ_7745c5c3_Var13 string templ_7745c5c3_Var13, templ_7745c5c3_Err = templ.JoinStringErrs(e.Timestamp.Format("03:04 PM, 02/01/06")) if templ_7745c5c3_Err != nil { - return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 63, Col: 280} + return templ.Error{Err: templ_7745c5c3_Err, FileName: `api/templates/index.templ`, Line: 68, Col: 280} } _, templ_7745c5c3_Err = templ_7745c5c3_Buffer.WriteString(templ.EscapeString(templ_7745c5c3_Var13)) if templ_7745c5c3_Err != nil { diff --git a/app/app.go b/app/app.go index e837639..972d250 100644 --- a/app/app.go +++ b/app/app.go @@ -34,13 +34,14 @@ func Run(args ...string) error { } } - srv, err := api.NewServer(conf.ServerListenAddr, s) + brc := api.NewBroadcaster(s) + srv, err := api.NewServer(conf.ServerListenAddr, s, brc) if err != nil { return fmt.Errorf("failed to api server: %w", err) } sh := signalhandler.New(os.Interrupt) - r := NewRunner(s, srv, sh) + r := NewRunner(s, srv, brc, sh) return r.Run() } diff --git a/go.mod b/go.mod index 286b5c6..cb107cf 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/jackc/pgx/v5 v5.4.3 github.com/magefile/mage v1.14.0 + github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc github.com/stretchr/testify v1.8.4 github.com/testcontainers/testcontainers-go/modules/compose v0.28.0 ) @@ -327,6 +328,7 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20231120223509-83a465c0220f // indirect google.golang.org/grpc v1.59.0 // indirect google.golang.org/protobuf v1.31.0 // indirect + gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 6713046..5367372 100644 --- a/go.sum +++ b/go.sum @@ -776,6 +776,8 @@ github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 h1:TCg2WBOl github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727/go.mod h1:rlzQ04UMyJXu/aOvhd8qT+hvDrFpiwqp8MRXDY9szc0= github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 h1:M8mH9eK4OUR4lu7Gd+PU1fV2/qnDNfzT635KRSObncs= github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567/go.mod h1:DWNGW8A4Y+GyBgPuaQJuWiy0XYftx4Xm/y5Jqk9I6VQ= +github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc h1:zAsgcP8MhzAbhMnB1QQ2O7ZhWYVGYSR2iVcjzQuPV+o= +github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc/go.mod h1:S8xSOnV3CgpNrWd0GQ/OoQfMtlg2uPRSuTzcSGrzwK8= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -1056,6 +1058,7 @@ golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -1357,6 +1360,8 @@ google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= +gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI= gopkg.in/cenkalti/backoff.v2 v2.2.1 h1:eJ9UAg01/HIHG987TwxvnzK2MgxXq97YY6rYDpY9aII= gopkg.in/cenkalti/backoff.v2 v2.2.1/go.mod h1:S0QdOvT2AlerfSBkp0O+dk+bbIMaNbEmVk876gPCthU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/integrationtests/integration_test.go b/integrationtests/integration_test.go index 6985bd3..68682b5 100644 --- a/integrationtests/integration_test.go +++ b/integrationtests/integration_test.go @@ -3,11 +3,16 @@ package integrationtests_test import ( + "context" + "fmt" "io" "net/http" "testing" "time" + "github.com/jackc/pgx/v5" + "github.com/r3labs/sse" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -24,39 +29,106 @@ func TestIntegration(t *testing.T) { time.Sleep(1 * time.Second) for _, test := range []func(*testing.T){ - testEmptyStatsJSON, - testEmptyStatsHTML, + index, + receiveStats, + receiveFirstEvent, + receiveEventsFromBuffer, } { t.Run(testName(test), test) } } -func testEmptyStatsJSON(t *testing.T) { - r, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8383/api/v1/stats", nil) +func index(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8383/", nil) require.NoError(t, err) - r.Header.Set("Accept", "application/json") resp, err := (&http.Client{}).Do(r) require.NoError(t, err) data, err := io.ReadAll(resp.Body) require.NoError(t, err) - const expectedResp = `{"registered_users":0,"programs":0,"proofs_generated":0,"proofs_verified":0}` - require.JSONEq(t, expectedResp, string(data)) + const expectedResp = `
` + require.Contains(t, string(data), expectedResp) } -func testEmptyStatsHTML(t *testing.T) { - r, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8383/api/v1/stats", nil) - require.NoError(t, err) +func receiveStats(t *testing.T) { + events := sseClient(t, "stats") + select { + case e := <-events: + expected := `
` + require.Contains(t, string(e.Data), expected) + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } +} - r.Header.Set("Accept", "test/html") - resp, err := (&http.Client{}).Do(r) - require.NoError(t, err) +func receiveFirstEvent(t *testing.T) { + events := sseClient(t, "tx-row") + notify(t, `{"state": "submitted","tx_id": "1234","prover_id": "5678","timestamp": "2006-01-02T15:04:05Z"}`) - data, err := io.ReadAll(resp.Body) - require.NoError(t, err) + select { + case e := <-events: + expected := `
` + require.Contains(t, string(e.Data), expected) + case <-time.After(time.Second * 5): + t.Fatal("timeout") + } +} + +func receiveEventsFromBuffer(t *testing.T) { + txs := []string{ + `{"state": "submitted","tx_id": "1234","prover_id": "5678","timestamp": "2006-01-02T15:04:05Z"}`, + `{"state": "submitted","tx_id": "1234","prover_id": "5678","timestamp": "2006-01-02T15:04:05Z"}`, + `{"state": "submitted","tx_id": "1234","prover_id": "5678","timestamp": "2006-01-02T15:04:05Z"}`, + `{"state": "submitted","tx_id": "1234","prover_id": "5678","timestamp": "2006-01-02T15:04:05Z"}`, + `{"state": "submitted","tx_id": "1234","prover_id": "5678","timestamp": "2006-01-02T15:04:05Z"}`, + `{"state": "submitted","tx_id": "1234","prover_id": "5678","timestamp": "2006-01-02T15:04:05Z"}`, + `{"state": "submitted","tx_id": "1234","prover_id": "5678","timestamp": "2006-01-02T15:04:05Z"}`, + `{"state": "submitted","tx_id": "1234","prover_id": "5678","timestamp": "2006-01-02T15:04:05Z"}`, + } + + notify(t, txs...) + + // Giver server some time to buffer events before starting sse client + time.Sleep(time.Second) + events := sseClient(t, "tx-row") + + expectedEvents := len(txs) + 1 // +1 for event added by receiveFirstEvent + for i := 0; i < expectedEvents; i++ { + select { + case e := <-events: + expected := `
` + assert.Contains(t, string(e.Data), expected) + case <-time.After(time.Second * 5): + t.Fatal("timeout") + } + } +} + +func sseClient(t *testing.T, event string) chan *sse.Event { + events := make(chan *sse.Event, 100) + client := sse.NewClient("http://127.0.0.1:8383/api/v1/stream") + go func() { + err := client.SubscribeRaw(func(msg *sse.Event) { + if string(msg.Event) == event { + select { + case events <- msg: + default: + } + } + }) + assert.NoError(t, err) + }() + t.Cleanup(func() { client.Unsubscribe(events) }) + return events +} - const expectedResp = `
0
Registered
Users
0
Provers
Deployed
0
Proofs
Generated
0
Proofs
Verified
` - require.Equal(t, expectedResp, string(data)) +func notify(t *testing.T, events ...string) { + conn, err := pgx.Connect(context.Background(), "postgres://gevulot:gevulot@localhost:5432/gevulot") + require.NoError(t, err) + for _, e := range events { + _, err = conn.Exec(context.Background(), fmt.Sprintf("NOTIFY tx_events, '%s';", e)) + require.NoError(t, err) + } } diff --git a/integrationtests/utils_test.go b/integrationtests/utils_test.go index 1787273..ca4f592 100644 --- a/integrationtests/utils_test.go +++ b/integrationtests/utils_test.go @@ -8,7 +8,6 @@ import ( "os/exec" "reflect" "runtime" - "strings" "testing" "github.com/jackc/pgx/v5" @@ -18,7 +17,7 @@ import ( ) func testName(test func(*testing.T)) string { - return strings.Split(runtime.FuncForPC(reflect.ValueOf(test).Pointer()).Name(), ".")[1] + return runtime.FuncForPC(reflect.ValueOf(test).Pointer()).Name() } func buildApp(t *testing.T) { diff --git a/signalhandler/signal_handler.go b/signalhandler/signal_handler.go index 13542f7..92cd79e 100644 --- a/signalhandler/signal_handler.go +++ b/signalhandler/signal_handler.go @@ -20,7 +20,11 @@ func New(signals ...os.Signal) *SignalHandler { func (sh *SignalHandler) Run() error { signal.Notify(sh.signalsCh, sh.signals...) - s := <-sh.signalsCh + s, ok := <-sh.signalsCh + if !ok { + return nil + } + slog.Info("Signal received", slog.String("signal", s.String())) return nil } diff --git a/store/mock/store.go b/store/mock/store.go index be1d8a8..d34f273 100644 --- a/store/mock/store.go +++ b/store/mock/store.go @@ -41,7 +41,7 @@ func (s *Store) Run() error { return nil case s.events <- randomEvent(): } - time.Sleep(5 * time.Second) + time.Sleep(1 * time.Second) } } @@ -51,7 +51,6 @@ func (s *Store) Events() <-chan model.Event { func (s *Store) Stop() error { close(s.done) - close(s.events) return nil } diff --git a/store/pg/store.go b/store/pg/store.go index 39605f2..78e888a 100644 --- a/store/pg/store.go +++ b/store/pg/store.go @@ -2,17 +2,25 @@ package pg import ( + "context" "database/sql" + "encoding/json" + "errors" + "fmt" + "log/slog" + "time" "github.com/gevulotnetwork/devnet-explorer/model" "github.com/go-gorp/gorp/v3" + "github.com/jackc/pgx/v5/stdlib" _ "github.com/jackc/pgx/v5/stdlib" ) type Store struct { db *gorp.DbMap events chan model.Event - done chan struct{} + ctx context.Context + cancel context.CancelFunc } func New(dsn string) (*Store, error) { @@ -21,28 +29,53 @@ func New(dsn string) (*Store, error) { return nil, err } + ctx, cancel := context.WithCancel(context.Background()) return &Store{ db: &gorp.DbMap{Db: db, Dialect: gorp.PostgresDialect{}}, events: make(chan model.Event, 1000), - done: make(chan struct{}), + ctx: ctx, + cancel: cancel, }, nil } func (s *Store) Run() error { defer close(s.events) - eventSource := make(chan model.Event) - for { - select { - case <-s.done: - return nil - case e := <-eventSource: - select { - case <-s.done: + + conn, err := s.db.Db.Conn(context.Background()) + if err != nil { + return fmt.Errorf("failed to get connection for listen/notify: %w", err) + } + + return conn.Raw(func(driverConn any) error { + conn := driverConn.(*stdlib.Conn).Conn() + _, err := conn.Exec(context.Background(), "listen tx_events") + if err != nil { + return err + } + + for { + n, err := conn.WaitForNotification(s.ctx) + if errors.Is(err, context.Canceled) { + slog.Info("pg notify listener stopped by context") return nil + } + + if err != nil { + return fmt.Errorf("error occurred while waiting for notification: %w", err) + } + + e := model.Event{} + if err = json.Unmarshal([]byte(n.Payload), &e); err != nil { + return fmt.Errorf("notification payload '%s': %w", n.Payload, err) + } + + select { case s.events <- e: + case <-time.After(time.Minute): + return errors.New("timeout waiting for event to be sent") } } - } + }) } func (s *Store) Stats() (model.Stats, error) { @@ -66,6 +99,7 @@ func (s *Store) Events() <-chan model.Event { } func (s *Store) Stop() error { - close(s.done) + s.cancel() + s.db.Db.Close() return nil } diff --git a/testdata/tables.sql b/testdata/tables.sql index 9e0ee16..b9af764 100644 --- a/testdata/tables.sql +++ b/testdata/tables.sql @@ -188,7 +188,8 @@ CREATE TABLE public.transaction ( nonce numeric NOT NULL, signature character varying(128) NOT NULL, propagated boolean, - executed boolean + executed boolean, + created_at timestamp with time zone DEFAULT now() );