diff --git a/pkg/bootstrap/wire_gen.go b/pkg/bootstrap/wire_gen.go index fe797dd..a9298c3 100644 --- a/pkg/bootstrap/wire_gen.go +++ b/pkg/bootstrap/wire_gen.go @@ -10,6 +10,7 @@ import ( "github.com/bobvawter/cacheroach/pkg/cache" "github.com/bobvawter/cacheroach/pkg/metrics" "github.com/bobvawter/cacheroach/pkg/store/blob" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/fs" "github.com/bobvawter/cacheroach/pkg/store/principal" "github.com/bobvawter/cacheroach/pkg/store/storetesting" @@ -60,7 +61,8 @@ func testRig(ctx context.Context) (*rig, func(), error) { DB: pool, Logger: logger, } - tokenServer, err := token.ProvideServer(configConfig, pool, logger) + notifier := cdc.ProvideNotifier(pool, logger) + tokenServer, cleanup6, err := token.ProvideServer(ctx, configConfig, pool, logger, notifier) if err != nil { cleanup5() cleanup4() @@ -79,6 +81,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { } bootstrapper, err := ProvideBootstrap(ctx, store, pool, fsStore, logger, server, tokenServer, tenantServer, vhostServer) if err != nil { + cleanup6() cleanup5() cleanup4() cleanup3() @@ -90,6 +93,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { Bootstrapper: bootstrapper, } return bootstrapRig, func() { + cleanup6() cleanup5() cleanup4() cleanup3() diff --git a/pkg/cmd/start/wire_gen.go b/pkg/cmd/start/wire_gen.go index 9e9f9db..4eaafaf 100644 --- a/pkg/cmd/start/wire_gen.go +++ b/pkg/cmd/start/wire_gen.go @@ -19,6 +19,7 @@ import ( "github.com/bobvawter/cacheroach/pkg/server/rest" "github.com/bobvawter/cacheroach/pkg/server/rpc" "github.com/bobvawter/cacheroach/pkg/store/blob" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/config" "github.com/bobvawter/cacheroach/pkg/store/fs" "github.com/bobvawter/cacheroach/pkg/store/principal" @@ -65,7 +66,8 @@ func newInjector(contextContext context.Context, cacheConfig *cache.Config, conf DB: pool, Logger: logger, } - tokenServer, err := token.ProvideServer(configConfig, pool, logger) + notifier := cdc.ProvideNotifier(pool, logger) + tokenServer, cleanup4, err := token.ProvideServer(contextContext, configConfig, pool, logger, notifier) if err != nil { cleanup3() cleanup2() @@ -82,6 +84,7 @@ func newInjector(contextContext context.Context, cacheConfig *cache.Config, conf } bootstrapper, err := bootstrap.ProvideBootstrap(contextContext, store, pool, fsStore, logger, principalServer, tokenServer, tenantServer, vhostServer) if err != nil { + cleanup4() cleanup3() cleanup2() cleanup() @@ -89,14 +92,16 @@ func newInjector(contextContext context.Context, cacheConfig *cache.Config, conf } connector, err := oidc.ProvideConnector(contextContext, factory, bootstrapper, commonConfig, logger, principalServer, tokenServer) if err != nil { + cleanup4() cleanup3() cleanup2() cleanup() return nil, nil, err } sessionWrapper := rest.ProvideSessionWrapper(bootstrapper, connector, tokenServer) - vHostMap, cleanup4, err := common.ProvideVHostMap(contextContext, logger, vhostServer) + vHostMap, cleanup5, err := common.ProvideVHostMap(contextContext, logger, vhostServer) if err != nil { + cleanup4() cleanup3() cleanup2() cleanup() @@ -109,6 +114,7 @@ func newInjector(contextContext context.Context, cacheConfig *cache.Config, conf wrapper := metrics.ProvideWrapper(factory) provision, err := rest.ProvideProvision(commonConfig, connector, logger, principalServer, pProfWrapper, latchWrapper, sessionWrapper, vHostWrapper) if err != nil { + cleanup5() cleanup4() cleanup3() cleanup2() @@ -118,6 +124,7 @@ func newInjector(contextContext context.Context, cacheConfig *cache.Config, conf retrieve := rest.ProvideRetrieve(logger, fsStore, pProfWrapper, latchWrapper, sessionWrapper, vHostWrapper) authInterceptor, err := rpc.ProvideAuthInterceptor(connector, logger, tokenServer) if err != nil { + cleanup5() cleanup4() cleanup3() cleanup2() @@ -143,6 +150,7 @@ func newInjector(contextContext context.Context, cacheConfig *cache.Config, conf } uploadServer, err := upload.ProvideServer(store, configConfig, pool, fsStore, logger) if err != nil { + cleanup5() cleanup4() cleanup3() cleanup2() @@ -151,6 +159,7 @@ func newInjector(contextContext context.Context, cacheConfig *cache.Config, conf } grpcServer, err := rpc.ProvideRPC(logger, authInterceptor, busyInterceptor, elideInterceptor, interceptor, vHostInterceptor, diags, fsServer, principalServer, tenantServer, tokenServer, uploadServer, vhostServer) if err != nil { + cleanup5() cleanup4() cleanup3() cleanup2() @@ -158,8 +167,9 @@ func newInjector(contextContext context.Context, cacheConfig *cache.Config, conf return nil, nil, err } publicMux := rest.ProvidePublicMux(cliConfigHandler, connector, fileHandler, wrapper, provision, retrieve, grpcServer) - serverServer, cleanup5, err := server.ProvideServer(contextContext, busyLatch, v, commonConfig, debugMux, logger, publicMux) + serverServer, cleanup6, err := server.ProvideServer(contextContext, busyLatch, v, commonConfig, debugMux, logger, publicMux) if err != nil { + cleanup5() cleanup4() cleanup3() cleanup2() @@ -170,6 +180,7 @@ func newInjector(contextContext context.Context, cacheConfig *cache.Config, conf Server: serverServer, } return startInjector, func() { + cleanup6() cleanup5() cleanup4() cleanup3() diff --git a/pkg/enforcer/wire_gen.go b/pkg/enforcer/wire_gen.go index 89323cd..a7d0604 100644 --- a/pkg/enforcer/wire_gen.go +++ b/pkg/enforcer/wire_gen.go @@ -10,6 +10,7 @@ import ( principal2 "github.com/bobvawter/cacheroach/api/principal" tenant2 "github.com/bobvawter/cacheroach/api/tenant" token2 "github.com/bobvawter/cacheroach/api/token" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/principal" "github.com/bobvawter/cacheroach/pkg/store/storetesting" "github.com/bobvawter/cacheroach/pkg/store/tenant" @@ -28,7 +29,8 @@ func testRig(ctx context.Context) (*rig, func(), error) { if err != nil { return nil, nil, err } - server, err := token.ProvideServer(config, pool, logger) + notifier := cdc.ProvideNotifier(pool, logger) + server, cleanup2, err := token.ProvideServer(ctx, config, pool, logger, notifier) if err != nil { cleanup() return nil, nil, err @@ -50,6 +52,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { tokens: server, } return enforcerRig, func() { + cleanup2() cleanup() }, nil } diff --git a/pkg/server/wire_gen.go b/pkg/server/wire_gen.go index 373ac5c..828eff8 100644 --- a/pkg/server/wire_gen.go +++ b/pkg/server/wire_gen.go @@ -22,6 +22,7 @@ import ( "github.com/bobvawter/cacheroach/pkg/server/rest" "github.com/bobvawter/cacheroach/pkg/server/rpc" "github.com/bobvawter/cacheroach/pkg/store/blob" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/config" "github.com/bobvawter/cacheroach/pkg/store/fs" "github.com/bobvawter/cacheroach/pkg/store/principal" @@ -84,7 +85,8 @@ func testRig(ctx context.Context) (*rig, func(), error) { DB: pool, Logger: logger, } - tokenServer, err := token.ProvideServer(configConfig, pool, logger) + notifier := cdc.ProvideNotifier(pool, logger) + tokenServer, cleanup6, err := token.ProvideServer(ctx, configConfig, pool, logger, notifier) if err != nil { cleanup5() cleanup4() @@ -103,6 +105,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { } bootstrapper, err := bootstrap.ProvideBootstrap(ctx, store, pool, fsStore, logger, server, tokenServer, tenantServer, vhostServer) if err != nil { + cleanup6() cleanup5() cleanup4() cleanup3() @@ -112,6 +115,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { } connector, err := oidc.ProvideConnector(ctx, factory, bootstrapper, config, logger, server, tokenServer) if err != nil { + cleanup6() cleanup5() cleanup4() cleanup3() @@ -120,8 +124,9 @@ func testRig(ctx context.Context) (*rig, func(), error) { return nil, nil, err } sessionWrapper := rest.ProvideSessionWrapper(bootstrapper, connector, tokenServer) - vHostMap, cleanup6, err := common.ProvideVHostMap(ctx, logger, vhostServer) + vHostMap, cleanup7, err := common.ProvideVHostMap(ctx, logger, vhostServer) if err != nil { + cleanup6() cleanup5() cleanup4() cleanup3() @@ -136,6 +141,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { wrapper := metrics.ProvideWrapper(factory) provision, err := rest.ProvideProvision(config, connector, logger, server, pProfWrapper, latchWrapper, sessionWrapper, vHostWrapper) if err != nil { + cleanup7() cleanup6() cleanup5() cleanup4() @@ -147,6 +153,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { retrieve := rest.ProvideRetrieve(logger, fsStore, pProfWrapper, latchWrapper, sessionWrapper, vHostWrapper) authInterceptor, err := rpc.ProvideAuthInterceptor(connector, logger, tokenServer) if err != nil { + cleanup7() cleanup6() cleanup5() cleanup4() @@ -174,6 +181,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { } uploadServer, err := upload.ProvideServer(store, configConfig, pool, fsStore, logger) if err != nil { + cleanup7() cleanup6() cleanup5() cleanup4() @@ -184,6 +192,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { } grpcServer, err := rpc.ProvideRPC(logger, authInterceptor, busyInterceptor, elideInterceptor, interceptor, vHostInterceptor, diags, fsServer, server, tenantServer, tokenServer, uploadServer, vhostServer) if err != nil { + cleanup7() cleanup6() cleanup5() cleanup4() @@ -193,8 +202,9 @@ func testRig(ctx context.Context) (*rig, func(), error) { return nil, nil, err } publicMux := rest.ProvidePublicMux(cliConfigHandler, connector, fileHandler, wrapper, provision, retrieve, grpcServer) - serverServer, cleanup7, err := ProvideServer(ctx, busyLatch, v, config, debugMux, logger, publicMux) + serverServer, cleanup8, err := ProvideServer(ctx, busyLatch, v, config, debugMux, logger, publicMux) if err != nil { + cleanup7() cleanup6() cleanup5() cleanup4() @@ -213,6 +223,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { vhosts: vhostServer, } return serverRig, func() { + cleanup8() cleanup7() cleanup6() cleanup5() diff --git a/pkg/store/cdc/cdc.go b/pkg/store/cdc/cdc.go new file mode 100644 index 0000000..511f665 --- /dev/null +++ b/pkg/store/cdc/cdc.go @@ -0,0 +1,162 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cdc contains a utility for receiving notifications whenever +// the contents of a database table are changed. +package cdc + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/Mandala/go-log" + "github.com/google/wire" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/pkg/errors" +) + +var ( + // Set is used by wire. + Set = wire.NewSet(ProvideNotifier) +) + +// Notifier is a factory for CDC notification channels. +type Notifier struct { + db *pgxpool.Pool + logger *log.Logger +} + +// ProvideNotifier is called by wire. +func ProvideNotifier(db *pgxpool.Pool, logger *log.Logger) *Notifier { + return &Notifier{ + db: db, + logger: logger, + } +} + +// A Notification is emitted at least once for each data update. +type Notification struct { + Table string // The table that was updated + Key json.RawMessage // The primary key for the table + Payload json.RawMessage // The JSON payload associated with the notification +} + +func (n *Notification) String() string { + return fmt.Sprintf("%s %s %s", n.Table, string(n.Key), string(n.Payload)) +} + +// Notify creates a new CDC notification channel which will run until +// the context is canceled. +func (n *Notifier) Notify(ctx context.Context, tables []string) <-chan *Notification { + // Set the feed cursor based on the caller's now. This avoids any + // "missed" updates since it may take a measurable amount of time in + // order to actually start the feed. + l := &loop{ + Notifier: n, + ch: make(chan *Notification, 16), + resolved: fmt.Sprintf("%d.0", time.Now().UnixNano()), + tables: tables, + } + + go func() { + defer close(l.ch) + for { + err := l.run(ctx) + select { + case <-ctx.Done(): + return + case <-time.After(time.Second): + n.logger.Debugf("restarting notification loop after: %v", err) + } + } + }() + + return l.ch +} + +type loop struct { + *Notifier + ch chan *Notification + tables []string + resolved string +} + +func (l *loop) run(ctx context.Context) error { + const ts = 10 * time.Second + const watch = 3 * ts + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + watchdog := time.NewTicker(watch) + defer watchdog.Stop() + go func() { + select { + case <-ctx.Done(): + case <-watchdog.C: + l.logger.Warnf("cdc watchdog timer firing") + cancel() + } + }() + + s := fmt.Sprintf( + "EXPERIMENTAL CHANGEFEED FOR %s WITH resolved='%s', no_initial_scan", + strings.Join(l.tables, ","), ts) + if l.resolved != "" { + s = fmt.Sprintf("%s, cursor='%s'", s, l.resolved) + } + l.logger.Tracef("creating changefeed using %q", s) + + rows, err := l.db.Query(ctx, s) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + // We'll see a NULL value for resolved-timestamp notifications. + var maybeTable *string + out := &Notification{} + if err := rows.Scan(&maybeTable, &out.Key, &out.Payload); err != nil { + return err + } + watchdog.Reset(watch) + + var envelope struct { + After json.RawMessage `json:"after"` + Resolved string `json:"resolved"` + } + if err := json.Unmarshal(out.Payload, &envelope); err != nil { + return errors.Wrap(err, "decoding envelope") + } + if envelope.Resolved != "" { + l.resolved = envelope.Resolved + l.logger.Tracef("updated resolved timestamp: %s", envelope.Resolved) + continue + } + if maybeTable != nil && len(envelope.After) > 0 { + out.Table = *maybeTable + out.Payload = envelope.After + + select { + case <-ctx.Done(): + return ctx.Err() + case l.ch <- out: + } + } + } + return rows.Err() +} diff --git a/pkg/store/cdc/cdc_test.go b/pkg/store/cdc/cdc_test.go new file mode 100644 index 0000000..0588e1e --- /dev/null +++ b/pkg/store/cdc/cdc_test.go @@ -0,0 +1,54 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cdc + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test(t *testing.T) { + a := assert.New(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + rig, cleanup, err := testRig(ctx) + if !a.NoError(err) { + return + } + defer cleanup() + + ch := rig.notifier.Notify(ctx, []string{"principals"}) + + if _, err := rig.db.Exec(ctx, + "INSERT INTO principals (principal, version) "+ + "VALUES (gen_random_uuid(), 1)", + ); !a.NoError(err) { + return + } + + select { + case <-ctx.Done(): + a.Fail("timed out") + case n := <-ch: + a.Equal("principals", n.Table) + if a.NotEmpty(n.Key) { + a.Equal(uint8('['), n.Key[0]) + } + a.NotEmpty(n.Payload) + } +} diff --git a/pkg/store/cdc/test_rig.go b/pkg/store/cdc/test_rig.go new file mode 100644 index 0000000..acd41cf --- /dev/null +++ b/pkg/store/cdc/test_rig.go @@ -0,0 +1,37 @@ +// Copyright 2021 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//+build wireinject + +package cdc + +import ( + "context" + + "github.com/bobvawter/cacheroach/pkg/store/storetesting" + "github.com/google/wire" + "github.com/jackc/pgx/v4/pgxpool" +) + +type rig struct { + db *pgxpool.Pool + notifier *Notifier +} + +func testRig(ctx context.Context) (*rig, func(), error) { + panic(wire.Build( + Set, + storetesting.Set, + wire.Struct(new(rig), "*"), + )) +} diff --git a/pkg/store/cdc/wire_gen.go b/pkg/store/cdc/wire_gen.go new file mode 100644 index 0000000..f71b58a --- /dev/null +++ b/pkg/store/cdc/wire_gen.go @@ -0,0 +1,45 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run github.com/google/wire/cmd/wire +//+build !wireinject + +package cdc + +import ( + "context" + "github.com/bobvawter/cacheroach/pkg/store/storetesting" + "github.com/jackc/pgx/v4/pgxpool" +) + +// Injectors from test_rig.go: + +func testRig(ctx context.Context) (*rig, func(), error) { + config, err := storetesting.ProvideStoreConfig() + if err != nil { + return nil, nil, err + } + logger := _wireLoggerValue + pool, cleanup, err := storetesting.ProvideDB(ctx, config, logger) + if err != nil { + return nil, nil, err + } + notifier := ProvideNotifier(pool, logger) + cdcRig := &rig{ + db: pool, + notifier: notifier, + } + return cdcRig, func() { + cleanup() + }, nil +} + +var ( + _wireLoggerValue = storetesting.Logger +) + +// test_rig.go: + +type rig struct { + db *pgxpool.Pool + notifier *Notifier +} diff --git a/pkg/store/schema/schema.go b/pkg/store/schema/schema.go index e66223f..2656eaf 100644 --- a/pkg/store/schema/schema.go +++ b/pkg/store/schema/schema.go @@ -39,36 +39,11 @@ func EnsureSchema(ctx context.Context, db *pgxpool.Pool, logger *log.Logger) err if err := tx.Commit(ctx); err != nil { return err } - logger.Info("schema setup complete") - - return nil -} - -// TruncateSchema truncates all tables in the schema. -func TruncateSchema(ctx context.Context, db *pgxpool.Pool, logger *log.Logger) error { - // Use cascade to do full cleanup. - names := []string{ - "chunks", - "ropes", - "tenants", - "principals", - "vhosts", - } - tx, err := db.Begin(ctx) - if err != nil { + if _, err := db.Exec(ctx, "SET CLUSTER SETTING kv.rangefeed.enabled = true"); err != nil { return err } - defer tx.Rollback(ctx) + logger.Info("schema setup complete") - for _, name := range names { - if _, err := tx.Exec(ctx, "TRUNCATE "+name+" CASCADE"); err != nil { - return err - } - } - if err := tx.Commit(ctx); err != nil { - return err - } - logger.Info("truncated all tables") return nil } diff --git a/pkg/store/set.go b/pkg/store/set.go index 3918c69..bfb7a92 100644 --- a/pkg/store/set.go +++ b/pkg/store/set.go @@ -15,6 +15,7 @@ package store import ( "github.com/bobvawter/cacheroach/pkg/store/blob" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/fs" "github.com/bobvawter/cacheroach/pkg/store/principal" "github.com/bobvawter/cacheroach/pkg/store/tenant" @@ -29,6 +30,7 @@ import ( // Combine with storetesting.Set for a ready-to-run stack. var Set = wire.NewSet( blob.Set, + cdc.Set, fs.Set, principal.Set, tenant.Set, diff --git a/pkg/store/token/test_rig.go b/pkg/store/token/test_rig.go index b978857..f794fe1 100644 --- a/pkg/store/token/test_rig.go +++ b/pkg/store/token/test_rig.go @@ -18,6 +18,7 @@ package token import ( "context" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/principal" "github.com/bobvawter/cacheroach/pkg/store/storetesting" "github.com/bobvawter/cacheroach/pkg/store/tenant" @@ -33,6 +34,7 @@ type rig struct { func testRig(ctx context.Context) (*rig, func(), error) { panic(wire.Build( Set, + cdc.Set, storetesting.Set, principal.Set, tenant.Set, diff --git a/pkg/store/token/token.go b/pkg/store/token/token.go index 96c6dab..1d0f161 100644 --- a/pkg/store/token/token.go +++ b/pkg/store/token/token.go @@ -15,6 +15,7 @@ package token import ( "context" + "encoding/json" "time" "github.com/Mandala/go-log" @@ -24,9 +25,11 @@ import ( "github.com/bobvawter/cacheroach/api/tenant" "github.com/bobvawter/cacheroach/api/token" "github.com/bobvawter/cacheroach/pkg/claims" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/config" "github.com/bobvawter/cacheroach/pkg/store/util" "github.com/dgrijalva/jwt-go/v4" + "github.com/google/uuid" "github.com/google/wire" lru "github.com/hashicorp/golang-lru" "github.com/jackc/pgconn" @@ -66,19 +69,45 @@ var _ token.TokensServer = (*Server)(nil) // ProvideServer is called by wire. func ProvideServer( + ctx context.Context, cfg *config.Config, db *pgxpool.Pool, logger *log.Logger, -) (*Server, error) { + notifier *cdc.Notifier, +) (*Server, func(), error) { if len(cfg.SigningKeys) == 0 { - return nil, errors.New("HMAC signing keys must be specified") + return nil, nil, errors.New("HMAC signing keys must be specified") } validations, err := lru.New2Q(1024 * 1024) if err != nil { - return nil, err + return nil, nil, err } + + ctx, cancel := context.WithCancel(ctx) + go func() { + ch := notifier.Notify(ctx, []string{"sessions"}) + for { + select { + case <-ctx.Done(): + return + case n := <-ch: + var payload struct { + Principal uuid.UUID `json:"principal"` + Session uuid.UUID `json:"session"` + } + if err := json.Unmarshal(n.Payload, &payload); err != nil { + logger.Warnf("could not unmarshal session notification: %v", err) + continue + } + logger.Tracef("invalidating session %s for %s", payload.Session, payload.Principal) + validations.Remove(payload.Principal) + validations.Remove(payload.Session) + } + } + }() + s := &Server{config: cfg, db: db, logger: logger, cache: validations} - return s, nil + return s, cancel, nil } // Current implements TokensServer. @@ -99,11 +128,12 @@ func (s *Server) Find(scope *session.Scope, server token.Tokens_FindServer) erro cached := found.(*cached) if cached.expires.After(time.Now()) { for i := range cached.sessions { - server.Send(cached.sessions[i]) + if err := server.Send(cached.sessions[i]); err != nil { + return err + } } return nil } - s.cache.Remove(cacheKey) } return util.RetryLoop(ctx, func(ctx context.Context, sideEffect *util.Marker) error { rows, err := s.db.Query(ctx, ` @@ -120,11 +150,9 @@ WHERE expires_at > now() } defer rows.Close() - const maxCachedSession = 16 - cacheOK := true var cache []*session.Session + var earliestExpiration time.Time - sideEffect.Mark() for rows.Next() { var caps capabilities.Capabilities var expires time.Time @@ -159,26 +187,21 @@ WHERE expires_at > now() continue } + sideEffect.Mark() if err := server.Send(s); err != nil { return err } - if cacheOK { - if len(cache) < maxCachedSession { - cache = append(cache, s) - } else { - cacheOK = false - cache = nil - } + cache = append(cache, s) + if earliestExpiration.IsZero() || expires.Before(earliestExpiration) { + earliestExpiration = expires } } - if cacheOK { - s.cache.Add(cacheKey, &cached{ - expires: time.Now().Add(time.Minute), - sessions: cache, - }) - } + s.cache.Add(cacheKey, &cached{ + expires: earliestExpiration, + sessions: cache, + }) return nil }) diff --git a/pkg/store/token/wire_gen.go b/pkg/store/token/wire_gen.go index c475beb..c49f750 100644 --- a/pkg/store/token/wire_gen.go +++ b/pkg/store/token/wire_gen.go @@ -7,6 +7,7 @@ package token import ( "context" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/principal" "github.com/bobvawter/cacheroach/pkg/store/storetesting" "github.com/bobvawter/cacheroach/pkg/store/tenant" @@ -24,7 +25,8 @@ func testRig(ctx context.Context) (*rig, func(), error) { if err != nil { return nil, nil, err } - server, err := ProvideServer(config, pool, logger) + notifier := cdc.ProvideNotifier(pool, logger) + server, cleanup2, err := ProvideServer(ctx, config, pool, logger, notifier) if err != nil { cleanup() return nil, nil, err @@ -44,6 +46,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { principals: principalServer, } return tokenRig, func() { + cleanup2() cleanup() }, nil } diff --git a/pkg/store/upload/test_rig.go b/pkg/store/upload/test_rig.go index 2bd6977..ea32eec 100644 --- a/pkg/store/upload/test_rig.go +++ b/pkg/store/upload/test_rig.go @@ -19,6 +19,7 @@ import ( "context" "github.com/bobvawter/cacheroach/pkg/store/blob" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/fs" "github.com/bobvawter/cacheroach/pkg/store/principal" "github.com/bobvawter/cacheroach/pkg/store/storetesting" @@ -39,6 +40,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { panic(wire.Build( Set, blob.Set, + cdc.Set, fs.Set, storetesting.Set, principal.Set, diff --git a/pkg/store/upload/upload.go b/pkg/store/upload/upload.go index ded889c..a32df92 100644 --- a/pkg/store/upload/upload.go +++ b/pkg/store/upload/upload.go @@ -233,7 +233,7 @@ func (s *Server) Fetch(ctx context.Context, req *upload.FetchRequest) (*upload.F r.Header.Add("x-cacheroach-session", sn.ID.AsUUID().String()) files := s.fs.FileSystem(req.Tenant) - if f, err := files.Open(req.Path); err == nil { + if f, err := files.OpenVersion(ctx, req.Path, -1); err == nil { if stat, err := f.Stat(); err == nil { r.Header.Add("if-modified-since", stat.ModTime().Format(http.TimeFormat)) } diff --git a/pkg/store/upload/wire_gen.go b/pkg/store/upload/wire_gen.go index 4e5ca4f..6a4fc65 100644 --- a/pkg/store/upload/wire_gen.go +++ b/pkg/store/upload/wire_gen.go @@ -10,6 +10,7 @@ import ( "github.com/bobvawter/cacheroach/pkg/cache" "github.com/bobvawter/cacheroach/pkg/metrics" "github.com/bobvawter/cacheroach/pkg/store/blob" + "github.com/bobvawter/cacheroach/pkg/store/cdc" "github.com/bobvawter/cacheroach/pkg/store/fs" "github.com/bobvawter/cacheroach/pkg/store/principal" "github.com/bobvawter/cacheroach/pkg/store/storetesting" @@ -72,7 +73,8 @@ func testRig(ctx context.Context) (*rig, func(), error) { DB: pool, Logger: logger, } - tokenServer, err := token.ProvideServer(configConfig, pool, logger) + notifier := cdc.ProvideNotifier(pool, logger) + tokenServer, cleanup6, err := token.ProvideServer(ctx, configConfig, pool, logger, notifier) if err != nil { cleanup5() cleanup4() @@ -89,6 +91,7 @@ func testRig(ctx context.Context) (*rig, func(), error) { tokens: tokenServer, } return uploadRig, func() { + cleanup6() cleanup5() cleanup4() cleanup3()