diff --git a/cmd/follower.go b/cmd/follower.go
index 0d4ab93c..ca70955b 100644
--- a/cmd/follower.go
+++ b/cmd/follower.go
@@ -108,6 +108,11 @@ func follower(_ *cobra.Command, _ []string) error {
 		return fmt.Errorf("failed to parse raft.logdb: %w", err)
 	}
 
+	nQueue := storage.NewNotificationQueue()
+	go nQueue.Run()
+	defer func() {
+		_ = nQueue.Close()
+	}()
 	engine, err := storage.New(storage.Config{
 		Log:           engineLog.Sugar(),
 		ClientAddress: viper.GetString("api.advertise-address"),
@@ -129,16 +134,17 @@ func follower(_ *cobra.Command, _ []string) error {
 			NodeName: viper.GetString("memberlist.node-name"),
 		},
 		Table: storage.TableConfig{
-			FS:                 vfs.Default,
-			ElectionRTT:        viper.GetUint64("raft.election-rtt"),
-			HeartbeatRTT:       viper.GetUint64("raft.heartbeat-rtt"),
-			SnapshotEntries:    viper.GetUint64("raft.snapshot-entries"),
-			CompactionOverhead: viper.GetUint64("raft.compaction-overhead"),
-			MaxInMemLogSize:    viper.GetUint64("raft.max-in-mem-log-size"),
-			DataDir:            viper.GetString("raft.state-machine-dir"),
-			RecoveryType:       toRecoveryType(viper.GetString("raft.snapshot-recovery-type")),
-			BlockCacheSize:     viper.GetInt64("storage.block-cache-size"),
-			TableCacheSize:     viper.GetInt("storage.table-cache-size"),
+			FS:                   vfs.Default,
+			ElectionRTT:          viper.GetUint64("raft.election-rtt"),
+			HeartbeatRTT:         viper.GetUint64("raft.heartbeat-rtt"),
+			SnapshotEntries:      viper.GetUint64("raft.snapshot-entries"),
+			CompactionOverhead:   viper.GetUint64("raft.compaction-overhead"),
+			MaxInMemLogSize:      viper.GetUint64("raft.max-in-mem-log-size"),
+			DataDir:              viper.GetString("raft.state-machine-dir"),
+			RecoveryType:         toRecoveryType(viper.GetString("raft.snapshot-recovery-type")),
+			BlockCacheSize:       viper.GetInt64("storage.block-cache-size"),
+			TableCacheSize:       viper.GetInt("storage.table-cache-size"),
+			AppliedIndexListener: nQueue.Notify,
 		},
 		Meta: storage.MetaConfig{
 			ElectionRTT: viper.GetUint64("raft.election-rtt"),
@@ -158,15 +164,14 @@ func follower(_ *cobra.Command, _ []string) error {
 	defer engine.Close()
 
 	// Replication
+	conn, err := createReplicationConn()
+	if err != nil {
+		return fmt.Errorf("cannot create replication conn: %w", err)
+	}
+	defer func() {
+		_ = conn.Close()
+	}()
 	{
-		conn, err := createReplicationConn()
-		defer func() {
-			_ = conn.Close()
-		}()
-		if err != nil {
-			return fmt.Errorf("cannot create replication conn: %w", err)
-		}
-
 		d := replication.NewManager(engine.Manager, engine.NodeHost, conn, replication.Config{
 			ReconcileInterval: viper.GetDuration("replication.reconcile-interval"),
 			Workers: replication.WorkerConfig{
@@ -192,11 +197,7 @@ func follower(_ *cobra.Command, _ []string) error {
 	if err != nil {
 		return fmt.Errorf("failed to create API server: %w", err)
 	}
-	regattapb.RegisterKVServer(regatta, &regattaserver.ReadonlyKVServer{
-		KVServer: regattaserver.KVServer{
-			Storage: engine,
-		},
-	})
+	regattapb.RegisterKVServer(regatta, regattaserver.NewForwardingKVServer(engine, regattapb.NewKVClient(conn), nQueue))
 	regattapb.RegisterClusterServer(regatta, &regattaserver.ClusterServer{
 		Cluster: engine,
 		Config:  viperConfigReader,
diff --git a/cmd/leader.go b/cmd/leader.go
index eaabdd51..698cba3f 100644
--- a/cmd/leader.go
+++ b/cmd/leader.go
@@ -226,6 +226,7 @@ func leader(_ *cobra.Command, _ []string) error {
 	)
 	regattapb.RegisterMetadataServer(replication, &regattaserver.MetadataServer{Tables: engine})
 	regattapb.RegisterSnapshotServer(replication, &regattaserver.SnapshotServer{Tables: engine})
+	regattapb.RegisterKVServer(replication, &regattaserver.KVServer{Storage: engine})
 	regattapb.RegisterLogServer(replication, ls)
 	// Start server
 	go func() {
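With these two changes the write path closes end to end: the follower builds the notification queue, wires `nQueue.Notify` into the table config as the `AppliedIndexListener`, and hands the leader connection plus the queue to the forwarding KV server, while the leader starts answering KV requests on its replication endpoint so followers have something to forward to. A condensed model of the resulting flow (a sketch; `forwardPut` is a hypothetical name, not code from this diff):

```go
package main

import (
	"context"

	"github.com/jamf/regatta/regattapb"
	"github.com/jamf/regatta/storage"
)

// forwardPut models the new follower write path: forward the write to the
// leader cluster, then block until the local replica catches up, so the
// follower can serve read-your-writes immediately afterwards.
func forwardPut(ctx context.Context, leader regattapb.KVClient, q *storage.IndexNotificationQueue, req *regattapb.PutRequest) (*regattapb.PutResponse, error) {
	resp, err := leader.Put(ctx, req) // 1. the leader applies the write
	if err != nil {
		return nil, err // ForwardingKVServer additionally translates the status code
	}
	// 2. wait until the local FSM reports an applied (leader) index at or past
	// the returned revision; the channel closes on success and yields ctx.Err()
	// if the caller gives up first.
	if err := <-q.Add(ctx, string(req.Table), resp.Header.Revision); err != nil {
		return nil, err
	}
	return resp, nil
}
```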
diff --git a/go.mod b/go.mod
index ef04ea76..f25b4556 100644
--- a/go.mod
+++ b/go.mod
@@ -10,6 +10,7 @@ require (
 	github.com/benbjohnson/clock v1.3.5
 	github.com/cenkalti/backoff/v4 v4.2.1
 	github.com/cockroachdb/pebble v0.0.0-20221207173255-0f086d933dac
+	github.com/google/uuid v1.4.0
 	github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
 	github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
 	github.com/hashicorp/memberlist v0.5.0
@@ -48,14 +49,13 @@ require (
 	github.com/cockroachdb/redact v1.1.5 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect
 	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
-	github.com/envoyproxy/protoc-gen-validate v1.0.2 // indirect
+	github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect
 	github.com/fsnotify/fsnotify v1.7.0 // indirect
-	github.com/getsentry/sentry-go v0.25.0 // indirect
+	github.com/getsentry/sentry-go v0.26.0 // indirect
 	github.com/gogo/protobuf v1.3.2 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/golang/snappy v0.0.4 // indirect
 	github.com/google/btree v1.1.2 // indirect
-	github.com/google/uuid v1.4.0 // indirect
 	github.com/hashicorp/errwrap v1.1.0 // indirect
 	github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
 	github.com/hashicorp/go-msgpack v0.5.5 // indirect
@@ -64,11 +64,13 @@ require (
 	github.com/hashicorp/golang-lru v1.0.2 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
 	github.com/huandu/xstrings v1.4.0 // indirect
+	github.com/iancoleman/strcase v0.3.0 // indirect
 	github.com/imdario/mergo v0.3.13 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/kr/pretty v0.3.1 // indirect
 	github.com/kr/text v0.2.0 // indirect
 	github.com/lni/goutils v1.4.0 // indirect
+	github.com/lyft/protoc-gen-star/v2 v2.0.3 // indirect
 	github.com/magiconair/properties v1.8.7 // indirect
 	github.com/miekg/dns v1.1.56 // indirect
 	github.com/mitchellh/copystructure v1.2.0 // indirect
diff --git a/go.sum b/go.sum
index 0e97dcdd..978ed0a7 100644
--- a/go.sum
+++ b/go.sum
@@ -92,6 +92,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
 github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
 github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA=
 github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7QWXpgx6x8QiMKdmN72jogE=
+github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A=
+github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
 github.com/etcd-io/bbolt v1.3.3/go.mod h1:ZF2nL25h33cCyBtcyWeZ2/I3HQOfTP+0PIEvHjkjCrw=
 github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8=
 github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
@@ -106,6 +108,8 @@ github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyT
 github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
 github.com/getsentry/sentry-go v0.25.0 h1:q6Eo+hS+yoJlTO3uu/azhQadsD8V+jQn2D8VvX1eOyI=
 github.com/getsentry/sentry-go v0.25.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY=
+github.com/getsentry/sentry-go v0.26.0 h1:IX3++sF6/4B5JcevhdZfdKIHfyvMmAq/UnqcyT2H6mA=
+github.com/getsentry/sentry-go v0.26.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY=
 github.com/ghemawat/stream v0.0.0-20171120220530-696b145b53b9/go.mod h1:106OIgooyS7OzLDOpUGgm9fA3bQENb/cFSyyBmMoJDs=
 github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s=
 github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM=
@@ -207,6 +211,8 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO
 github.com/huandu/xstrings v1.4.0 h1:D17IlohoQq4UcpqD7fDk80P7l+lwAmlFaBHgOipl2FU=
 github.com/huandu/xstrings v1.4.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
 github.com/hydrogen18/memlistener v0.0.0-20141126152155-54553eb933fb/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE=
+github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI=
+github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho=
 github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk=
 github.com/imdario/mergo v0.3.13/go.mod h1:4lJ1jqUDcsbIECGy0RUJAXNIhg+6ocWgb1ALK2O4oXg=
 github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
@@ -257,6 +263,8 @@ github.com/lni/goutils v1.4.0 h1:e1tNN+4zsbTpNvhG5cxirkH9Pdz96QAZ2j6+5tmjvqg=
 github.com/lni/goutils v1.4.0/go.mod h1:LIHvF0fflR+zyXUQFQOiHPpKANf3UIr7DFIv5CBPOoU=
 github.com/lni/vfs v0.2.1-0.20220616104132-8852fd867376 h1:jX9CoRWNPwrZ2yY3RJFTSwa49qDQqtXglrCByGdQGZg=
 github.com/lni/vfs v0.2.1-0.20220616104132-8852fd867376/go.mod h1:LOatfyR8Xeej1jbXybwYGVfCccR0u+BQRG9xg7BD7xo=
+github.com/lyft/protoc-gen-star/v2 v2.0.3 h1:/3+/2sWyXeMLzKd1bX+ixWKgEMsULrIivpDsuaF441o=
+github.com/lyft/protoc-gen-star/v2 v2.0.3/go.mod h1:amey7yeodaJhXSbf/TlLvWiqQfLOSpEk//mLlc+axEk=
 github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
 github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
 github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
diff --git a/regattaserver/kv.go b/regattaserver/kv.go
index 6bf6f4ce..0a53122d 100644
--- a/regattaserver/kv.go
+++ b/regattaserver/kv.go
@@ -5,8 +5,10 @@ package regattaserver
 import (
 	"context"
 	"errors"
+	"fmt"
 
 	"github.com/jamf/regatta/regattapb"
+	"github.com/jamf/regatta/storage"
 	serrors "github.com/jamf/regatta/storage/errors"
 	"github.com/jamf/regatta/util/iter"
 	"google.golang.org/grpc/codes"
@@ -179,19 +181,48 @@ func (s *KVServer) Txn(ctx context.Context, req *regattapb.TxnRequest) (*regatta
 	return r, nil
 }
 
-// ReadonlyKVServer implements read part of KV service from proto/regatta.proto.
-type ReadonlyKVServer struct {
+func NewForwardingKVServer(storage KVService, client regattapb.KVClient, q *storage.IndexNotificationQueue) *ForwardingKVServer {
+	return &ForwardingKVServer{
+		KVServer: KVServer{Storage: storage},
+		client:   client,
+		q:        q,
+	}
+}
+
+type propagationQueue interface {
+	Add(ctx context.Context, table string, revision uint64) <-chan error
+}
+
+// ForwardingKVServer forwards the write operations to the leader cluster.
+type ForwardingKVServer struct {
 	KVServer
+	client regattapb.KVClient
+	q      propagationQueue
 }
 
 // Put implements proto/regatta.proto KV.Put method.
-func (r *ReadonlyKVServer) Put(_ context.Context, _ *regattapb.PutRequest) (*regattapb.PutResponse, error) {
-	return nil, status.Error(codes.Unimplemented, "method Put not implemented for follower")
+func (r *ForwardingKVServer) Put(ctx context.Context, req *regattapb.PutRequest) (*regattapb.PutResponse, error) {
+	put, err := r.client.Put(ctx, req)
+	if err != nil {
+		if s, ok := status.FromError(err); ok {
+			return nil, status.Error(s.Code(), fmt.Sprintf("leader error: %v", s.Err()))
+		}
+		return nil, status.Error(codes.FailedPrecondition, "forward error")
+	}
+
+	return put, <-r.q.Add(ctx, string(req.Table), put.Header.Revision)
 }
 
 // DeleteRange implements proto/regatta.proto KV.DeleteRange method.
-func (r *ReadonlyKVServer) DeleteRange(_ context.Context, _ *regattapb.DeleteRangeRequest) (*regattapb.DeleteRangeResponse, error) {
-	return nil, status.Error(codes.Unimplemented, "method DeleteRange not implemented for follower")
+func (r *ForwardingKVServer) DeleteRange(ctx context.Context, req *regattapb.DeleteRangeRequest) (*regattapb.DeleteRangeResponse, error) {
+	del, err := r.client.DeleteRange(ctx, req)
+	if err != nil {
+		if s, ok := status.FromError(err); ok {
+			return nil, status.Error(s.Code(), fmt.Sprintf("leader error: %v", s.Err()))
+		}
+		return nil, status.Error(codes.FailedPrecondition, "forward error")
+	}
+	return del, <-r.q.Add(ctx, string(req.Table), del.Header.Revision)
 }
 
 // Txn processes multiple requests in a single transaction.
@@ -199,9 +230,16 @@ func (r *ReadonlyKVServer) DeleteRan
 // and generates events with the same revision for every completed request.
 // It is allowed to modify the same key several times within one txn (the result will be the last Op that modified the key).
 // Readonly transactions allowed using follower API.
-func (r *ReadonlyKVServer) Txn(ctx context.Context, req *regattapb.TxnRequest) (*regattapb.TxnResponse, error) {
+func (r *ForwardingKVServer) Txn(ctx context.Context, req *regattapb.TxnRequest) (*regattapb.TxnResponse, error) {
 	if req.IsReadonly() {
 		return r.KVServer.Txn(ctx, req)
 	}
-	return nil, status.Error(codes.Unimplemented, "writable Txn not implemented for follower")
+	txn, err := r.client.Txn(ctx, req)
+	if err != nil {
+		if s, ok := status.FromError(err); ok {
+			return nil, status.Error(s.Code(), fmt.Sprintf("leader error: %v", s.Err()))
+		}
+		return nil, status.Error(codes.FailedPrecondition, "forward error")
+	}
+	return txn, <-r.q.Add(ctx, string(req.Table), txn.Header.Revision)
 }
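All three forwarding methods repeat the same status-translation block. A shared helper would keep the behavior in sync; a minimal sketch (`convertLeaderErr` is hypothetical, not part of this diff):

```go
package regattaserver

import (
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// convertLeaderErr mirrors the translation above: gRPC statuses returned by
// the leader keep their code, anything else becomes FailedPrecondition.
func convertLeaderErr(err error) error {
	if s, ok := status.FromError(err); ok {
		return status.Error(s.Code(), fmt.Sprintf("leader error: %v", s.Err()))
	}
	return status.Error(codes.FailedPrecondition, "forward error")
}
```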
diff --git a/regattaserver/kv_test.go b/regattaserver/kv_test.go
index 31ee218a..a3608507 100644
--- a/regattaserver/kv_test.go
+++ b/regattaserver/kv_test.go
@@ -14,6 +14,7 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/status"
@@ -532,51 +533,64 @@ func TestKVServer_Txn(t *testing.T) {
 	r.NoError(err)
 }
 
-func TestReadonlyKVServer_Put(t *testing.T) {
+func TestForwardingKVServer_Put(t *testing.T) {
 	r := require.New(t)
-	kv := ReadonlyKVServer{
+	client := &mockClient{}
+	kv := ForwardingKVServer{
 		KVServer: KVServer{
 			Storage: &mockKVService{},
 		},
+		client: client,
+		q:      fakeQueue{},
 	}
-
-	t.Log("Put kv")
-	_, err := kv.Put(context.Background(), &regattapb.PutRequest{
+	ctx := context.Background()
+	req := &regattapb.PutRequest{
 		Table: table1Name,
 		Key:   key1Name,
-	})
-	r.EqualError(err, status.Errorf(codes.Unimplemented, "method Put not implemented for follower").Error())
+	}
+	client.On("Put", ctx, req, mock.Anything).Return(&regattapb.PutResponse{Header: &regattapb.ResponseHeader{Revision: 1}}, nil)
+	t.Log("Put kv")
+	resp, err := kv.Put(ctx, req)
+	r.NoError(err)
+	r.Equal(uint64(1), resp.Header.Revision)
 }
 
-func TestReadonlyKVServer_DeleteRange(t *testing.T) {
+func TestForwardingKVServer_DeleteRange(t *testing.T) {
 	r := require.New(t)
-	kv := ReadonlyKVServer{
+	client := &mockClient{}
+	kv := ForwardingKVServer{
 		KVServer: KVServer{
 			Storage: &mockKVService{},
 		},
+		client: client,
+		q:      fakeQueue{},
 	}
-
-	t.Log("Delete existing kv")
-	_, err := kv.DeleteRange(context.Background(), &regattapb.DeleteRangeRequest{
+	ctx := context.Background()
+	req := &regattapb.DeleteRangeRequest{
 		Table: table1Name,
 		Key:   key1Name,
-	})
-	r.EqualError(err, status.Errorf(codes.Unimplemented, "method DeleteRange not implemented for follower").Error())
+	}
+	client.On("DeleteRange", ctx, req, mock.Anything).Return(&regattapb.DeleteRangeResponse{Header: &regattapb.ResponseHeader{Revision: 1}}, nil)
+	t.Log("Delete existing kv")
+	resp, err := kv.DeleteRange(ctx, req)
+	r.NoError(err)
+	r.Equal(uint64(1), resp.Header.Revision)
 }
 
-func TestReadonlyKVServer_Txn(t *testing.T) {
+func TestForwardingKVServer_Txn(t *testing.T) {
 	r := require.New(t)
 	storage := &mockKVService{}
-	kv := ReadonlyKVServer{
+	client := &mockClient{}
+	kv := ForwardingKVServer{
 		KVServer: KVServer{
 			Storage: storage,
 		},
+		client: client,
+		q:      fakeQueue{},
 	}
-	storage.On("Txn", mock.Anything, mock.AnythingOfType("*regattapb.TxnRequest")).Return(&regattapb.TxnResponse{}, nil)
-
-	t.Log("Writable Txn")
-	_, err := kv.Txn(context.Background(), &regattapb.TxnRequest{
+	ctx := context.Background()
+	req := &regattapb.TxnRequest{
 		Success: []*regattapb.RequestOp{
 			{
®attapb.RequestOp_RequestPut{RequestPut: ®attapb.RequestOp_Put{ @@ -584,11 +598,14 @@ func TestReadonlyKVServer_Txn(t *testing.T) { }}, }, }, - }) - r.EqualError(err, status.Errorf(codes.Unimplemented, "writable Txn not implemented for follower").Error()) + } + t.Log("Writable Txn") + client.On("Txn", ctx, req, mock.Anything).Return(®attapb.TxnResponse{Header: ®attapb.ResponseHeader{Revision: 1}}, nil) + resp, err := kv.Txn(ctx, req) + r.NoError(err) + r.Equal(uint64(1), resp.Header.Revision) - t.Log("Readonly Txn") - _, err = kv.Txn(context.Background(), ®attapb.TxnRequest{ + req = ®attapb.TxnRequest{ Table: table1Name, Success: []*regattapb.RequestOp{ { @@ -597,7 +614,11 @@ func TestReadonlyKVServer_Txn(t *testing.T) { }}, }, }, - }) + } + ctx = context.Background() + storage.On("Txn", ctx, req).Return(®attapb.TxnResponse{}, nil) + t.Log("Readonly Txn") + _, err = kv.Txn(ctx, req) r.NoError(err) } @@ -662,3 +683,40 @@ func (m *mockIterateRangeServer) SendMsg(mes any) error { func (m *mockIterateRangeServer) RecvMsg(mes any) error { return m.Mock.Called(mes).Error(0) } + +type mockClient struct { + mock.Mock +} + +func (m *mockClient) Range(ctx context.Context, in *regattapb.RangeRequest, opts ...grpc.CallOption) (*regattapb.RangeResponse, error) { + called := m.Mock.Called(ctx, in, opts) + return called.Get(0).(*regattapb.RangeResponse), called.Error(1) +} + +func (m *mockClient) IterateRange(ctx context.Context, in *regattapb.RangeRequest, opts ...grpc.CallOption) (regattapb.KV_IterateRangeClient, error) { + called := m.Mock.Called(ctx, in, opts) + return called.Get(0).(regattapb.KV_IterateRangeClient), called.Error(1) +} + +func (m *mockClient) Put(ctx context.Context, in *regattapb.PutRequest, opts ...grpc.CallOption) (*regattapb.PutResponse, error) { + called := m.Mock.Called(ctx, in, opts) + return called.Get(0).(*regattapb.PutResponse), called.Error(1) +} + +func (m *mockClient) DeleteRange(ctx context.Context, in *regattapb.DeleteRangeRequest, opts ...grpc.CallOption) (*regattapb.DeleteRangeResponse, error) { + called := m.Mock.Called(ctx, in, opts) + return called.Get(0).(*regattapb.DeleteRangeResponse), called.Error(1) +} + +func (m *mockClient) Txn(ctx context.Context, in *regattapb.TxnRequest, opts ...grpc.CallOption) (*regattapb.TxnResponse, error) { + called := m.Mock.Called(ctx, in, opts) + return called.Get(0).(*regattapb.TxnResponse), called.Error(1) +} + +type fakeQueue struct{} + +func (f fakeQueue) Add(ctx context.Context, table string, revision uint64) <-chan error { + i := make(chan error) + close(i) + return i +} diff --git a/storage/engine.go b/storage/engine.go index 0de06efc..38ff55ef 100644 --- a/storage/engine.go +++ b/storage/engine.go @@ -56,7 +56,7 @@ func New(cfg Config) (*Engine, error) { }, ) if cfg.LogCacheSize > 0 { - e.LogCache = &logreader.ShardCache{ShardCacheSize: cfg.LogCacheSize} + e.LogCache = logreader.NewShardCache(cfg.LogCacheSize) e.LogReader = &logreader.Cached{LogQuerier: nh, ShardCache: e.LogCache} } else { e.LogReader = &logreader.Simple{LogQuerier: nh} diff --git a/storage/engine_events.go b/storage/engine_events.go index 2367f959..abbf2cdc 100644 --- a/storage/engine_events.go +++ b/storage/engine_events.go @@ -17,12 +17,7 @@ func (e *events) dispatchEvents() { switch ev := evt.(type) { case nodeHostShuttingDown: return - case leaderUpdated, nodeUnloaded, membershipChanged: - e.engine.Cluster.Notify() - case nodeReady: - if ev.ReplicaID == e.engine.cfg.NodeID && e.engine.LogCache != nil { - e.engine.LogCache.NodeReady(ev.ShardID) - } + 
diff --git a/storage/logreader/logreader.go b/storage/logreader/logreader.go
index 42eb9ac4..6bbe2c58 100644
--- a/storage/logreader/logreader.go
+++ b/storage/logreader/logreader.go
@@ -41,20 +41,15 @@ func (l *Simple) QueryRaftLog(ctx context.Context, clusterID uint64, logRange dr
 }
 
 type ShardCache struct {
-	shardCache     util.SyncMap[uint64, *shard]
-	ShardCacheSize int
+	shardCache *util.SyncMap[uint64, *shard]
 }
 
 func (l *ShardCache) NodeDeleted(shardID uint64) {
 	l.shardCache.Delete(shardID)
 }
 
-func (l *ShardCache) NodeReady(shardID uint64) {
-	l.shardCache.ComputeIfAbsent(shardID, func(shardId uint64) *shard { return &shard{cache: newCache(l.ShardCacheSize)} })
-}
-
 func (l *ShardCache) LogCompacted(shardID uint64) {
-	l.shardCache.Store(shardID, &shard{cache: newCache(l.ShardCacheSize)})
+	l.shardCache.Delete(shardID)
 }
 
 type Cached struct {
@@ -156,3 +151,10 @@ func fixSize(entries []raftpb.Entry, maxSize uint64) []raftpb.Entry {
 	}
 	return entries
 }
+
+// NewShardCache creates a ShardCache whose per-shard caches are created lazily on first access.
+func NewShardCache(size int) *ShardCache {
+	return &ShardCache{shardCache: util.NewSyncMap(func(k uint64) *shard {
+		return &shard{cache: newCache(size)}
+	})}
+}
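The defaulter-backed SyncMap is what lets ShardCache drop both NodeReady and the explicit re-creation in LogCompacted: a deleted shard is simply rebuilt by the defaulter on the next Load. A small illustration of that lifecycle (a sketch; it assumes `Delete`, which ShardCache already calls, and uses a string value for brevity):

```go
package main

import (
	"fmt"

	"github.com/jamf/regatta/util"
)

func main() {
	// With a defaulter configured, Load never misses: absent keys are
	// materialized on first access.
	m := util.NewSyncMap(func(k uint64) string { return fmt.Sprintf("shard-%d", k) })
	v, _ := m.Load(7) // created on demand
	fmt.Println(v)    // shard-7
	m.Delete(7)       // e.g. log compaction drops the cached shard
	v, _ = m.Load(7)  // transparently recreated by the defaulter
	fmt.Println(v)    // shard-7 (a fresh instance)
}
```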
diff --git a/storage/logreader/logreader_test.go b/storage/logreader/logreader_test.go
index 3c7fef7b..3e63d8c6 100644
--- a/storage/logreader/logreader_test.go
+++ b/storage/logreader/logreader_test.go
@@ -9,6 +9,7 @@ import (
 
 	serror "github.com/jamf/regatta/storage/errors"
 	"github.com/jamf/regatta/util"
+	"github.com/jamf/regatta/util/iter"
 	"github.com/lni/dragonboat/v4"
 	"github.com/lni/dragonboat/v4/raftio"
 	"github.com/lni/dragonboat/v4/raftpb"
@@ -69,32 +70,30 @@ func TestCached_NodeDeleted(t *testing.T) {
 	}{
 		{
 			name:   "remove existing cache shard",
-			fields: fields{ShardCache: &ShardCache{}},
+			fields: fields{ShardCache: NewShardCache(1)},
 			args: args{info: raftio.NodeInfo{
 				ShardID:   1,
 				ReplicaID: 1,
 			}},
 			assert: func(t *testing.T, s *util.SyncMap[uint64, *shard]) {
-				_, ok := s.Load(uint64(1))
-				require.False(t, ok, "unexpected cache shard")
+				require.False(t, iter.Contains(s.Keys(), uint64(1)), "unexpected cache shard")
 			},
 		},
 		{
 			name:   "remove non-existent cache shard",
-			fields: fields{ShardCache: &ShardCache{}},
+			fields: fields{ShardCache: NewShardCache(1)},
 			args: args{info: raftio.NodeInfo{
 				ShardID:   1,
 				ReplicaID: 1,
 			}},
 			assert: func(t *testing.T, s *util.SyncMap[uint64, *shard]) {
-				_, ok := s.Load(uint64(1))
-				require.False(t, ok, "unexpected cache shard")
+				require.False(t, iter.Contains(s.Keys(), uint64(1)), "unexpected cache shard")
 			},
 		},
 		{
 			name: "remove existent cache shard from list",
 			fields: fields{ShardCache: func() *ShardCache {
-				c := &ShardCache{}
+				c := NewShardCache(100)
 				for i := 1; i <= 4; i++ {
 					c.shardCache.Store(uint64(i), &shard{})
 				}
@@ -105,8 +104,7 @@ func TestCached_NodeDeleted(t *testing.T) {
 				ReplicaID: 1,
 			}},
 			assert: func(t *testing.T, s *util.SyncMap[uint64, *shard]) {
-				_, ok := s.Load(uint64(2))
-				require.False(t, ok, "unexpected cache shard")
+				require.False(t, iter.Contains(s.Keys(), uint64(2)), "unexpected cache shard")
 			},
 		},
 	}
@@ -114,77 +112,7 @@ func TestCached_NodeDeleted(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			l := &Cached{ShardCache: tt.fields.ShardCache}
 			l.NodeDeleted(tt.args.info.ShardID)
-			tt.assert(t, &l.shardCache)
-		})
-	}
-}
-
-func TestCached_NodeReady(t *testing.T) {
-	type fields struct {
-		ShardCache *ShardCache
-	}
-	type args struct {
-		info raftio.NodeInfo
-	}
-	tests := []struct {
-		name   string
-		args   args
-		fields fields
-		assert func(*testing.T, *util.SyncMap[uint64, *shard])
-	}{
-		{
-			name:   "add ready node",
-			fields: fields{ShardCache: &ShardCache{}},
-			args: args{info: raftio.NodeInfo{
-				ShardID:   1,
-				ReplicaID: 1,
-			}},
-			assert: func(t *testing.T, s *util.SyncMap[uint64, *shard]) {
-				_, ok := s.Load(uint64(1))
-				require.True(t, ok, "missing cache shard")
-			},
-		},
-		{
-			name: "add existing node",
-			fields: fields{ShardCache: func() *ShardCache {
-				c := &ShardCache{}
-				c.shardCache.Store(uint64(1), &shard{})
-				return c
-			}()},
-			args: args{info: raftio.NodeInfo{
-				ShardID:   1,
-				ReplicaID: 1,
-			}},
-			assert: func(t *testing.T, s *util.SyncMap[uint64, *shard]) {
-				_, ok := s.Load(uint64(1))
-				require.True(t, ok, "missing cache shard")
-			},
-		},
-		{
-			name: "add ready node to list",
-			fields: fields{ShardCache: func() *ShardCache {
-				c := &ShardCache{}
-				c.shardCache.Store(uint64(1), &shard{})
-				c.shardCache.Store(uint64(3), &shard{})
-				c.shardCache.Store(uint64(5), &shard{})
-				c.shardCache.Store(uint64(6), &shard{})
-				return c
-			}()},
-			args: args{info: raftio.NodeInfo{
-				ShardID:   2,
-				ReplicaID: 1,
-			}},
-			assert: func(t *testing.T, s *util.SyncMap[uint64, *shard]) {
-				_, ok := s.Load(uint64(2))
-				require.True(t, ok, "missing cache shard")
-			},
-		},
-	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			l := &Cached{ShardCache: tt.fields.ShardCache}
-			l.NodeReady(tt.args.info.ShardID)
-			tt.assert(t, &l.shardCache)
+			tt.assert(t, l.shardCache)
 		})
 	}
 }
@@ -483,10 +411,9 @@ func TestCached_QueryRaftLog(t *testing.T) {
 				tt.on(querier)
 			}
 			l := &Cached{
-				ShardCache: &ShardCache{ShardCacheSize: tt.fields.ShardCacheSize},
+				ShardCache: NewShardCache(tt.fields.ShardCacheSize),
 				LogQuerier: querier,
 			}
-			l.shardCache.ComputeIfAbsent(tt.args.clusterID, func(uint642 uint64) *shard { return &shard{cache: newCache(tt.fields.ShardCacheSize)} })
 			if len(tt.cacheContent) > 0 {
 				v, _ := l.shardCache.Load(tt.args.clusterID)
diff --git a/storage/queue.go b/storage/queue.go
new file mode 100644
index 00000000..187a10aa
--- /dev/null
+++ b/storage/queue.go
@@ -0,0 +1,124 @@
+// Copyright JAMF Software, LLC
+
+package storage
+
+import (
+	"cmp"
+	"context"
+	"time"
+
+	"github.com/jamf/regatta/util"
+	"github.com/jamf/regatta/util/heap"
+	"github.com/jamf/regatta/util/iter"
+)
+
+type item struct {
+	ctx      context.Context
+	table    string
+	revision uint64
+	waitCh   chan error
+}
+
+func (i *item) less(other *item) bool {
+	return cmp.Less(i.revision, other.revision)
+}
+
+type notification struct {
+	table    string
+	revision uint64
+}
+
+// IndexNotificationQueue tracks waiters per table in a min-heap ordered by the revision they wait for.
+type IndexNotificationQueue struct {
+	items  *util.SyncMap[string, *heap.Heap[*item]]
+	add    chan *item
+	notif  chan notification
+	closed chan struct{}
+}
+
+func NewNotificationQueue() *IndexNotificationQueue {
+	return &IndexNotificationQueue{
+		add:    make(chan *item),
+		notif:  make(chan notification),
+		closed: make(chan struct{}),
+		items: util.NewSyncMap(func(k string) *heap.Heap[*item] {
+			return heap.New((*item).less)
+		}),
+	}
+}
+
+// Run processes additions, notifications and periodic GC of canceled waiters; it blocks until Close.
+func (q *IndexNotificationQueue) Run() {
+	gc := time.NewTicker(time.Second)
+	defer gc.Stop()
+	for {
+		select {
+		case <-q.closed:
+			return
+		case <-gc.C:
+			iter.Consume(q.items.Values(), func(h *heap.Heap[*item]) {
+				l := h.Len()
+				for i := 0; i < l; i++ {
+					elem := h.Slice[i]
+					if elem.ctx.Err() != nil {
+						// Reorder: a zero revision sorts the canceled item to the top of the min-heap.
+						elem.revision = 0
+						elem.waitCh <- elem.ctx.Err()
+						// Sift the changed index up; the heap is valid everywhere else,
+						// so fixing only the root would leave the invariant broken.
+						h.Fix(i)
+					}
+				}
+				for i := 0; i < l; i++ {
+					elem := h.Peek()
+					if elem.revision == 0 {
+						h.Pop()
+					} else {
+						break
+					}
+				}
+			})
+		case it := <-q.add:
+			h, _ := q.items.Load(it.table)
+			h.Push(it)
+		case n := <-q.notif:
+			h, _ := q.items.Load(n.table)
+			l := h.Len()
+			for i := 0; i < l; i++ {
+				elem := h.Peek()
+				if elem.ctx.Err() != nil {
+					elem.waitCh <- elem.ctx.Err()
+					h.Pop()
+				} else if elem.revision <= n.revision {
+					close(elem.waitCh)
+					h.Pop()
+				} else {
+					break
+				}
+			}
+		}
+	}
+}
+
+// Notify releases every waiter of table whose requested revision is <= revision.
+func (q *IndexNotificationQueue) Notify(table string, revision uint64) {
+	q.notif <- notification{table: table, revision: revision}
+}
+
+// Close stops the Run loop.
+func (q *IndexNotificationQueue) Close() error {
+	close(q.closed)
+	return nil
+}
+
+// Add registers a waiter; the returned channel is closed once table reaches revision,
+// or receives ctx.Err() if the context ends first.
+func (q *IndexNotificationQueue) Add(ctx context.Context, table string, revision uint64) <-chan error {
+	ch := make(chan error, 1)
+	q.add <- &item{
+		ctx:      ctx,
+		table:    table,
+		revision: revision,
+		waitCh:   ch,
+	}
+	return ch
+}
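The queue's contract in isolation (a sketch using only the exported API; the caller owns the Run goroutine, exactly as cmd/follower.go does):

```go
package main

import (
	"context"
	"fmt"

	"github.com/jamf/regatta/storage"
)

func main() {
	q := storage.NewNotificationQueue()
	go q.Run()
	defer q.Close()

	// A waiter registers the revision its table must reach locally.
	done := make(chan error, 1)
	go func() { done <- <-q.Add(context.Background(), "accounts", 42) }()

	// The applied-index listener reports progress; every waiter at or below
	// the notified revision is released (channel closed, so the error is nil).
	q.Notify("accounts", 50)
	fmt.Println(<-done) // <nil>
}
```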
diff --git a/storage/queue_test.go b/storage/queue_test.go
new file mode 100644
index 00000000..74d2d1c6
--- /dev/null
+++ b/storage/queue_test.go
@@ -0,0 +1,71 @@
+// Copyright JAMF Software, LLC
+
+package storage
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestNotificationQueue(t *testing.T) {
+	q := NewNotificationQueue()
+	go q.Run()
+	defer q.Close()
+
+	w := q.Add(context.Background(), "foo", 1)
+	require.Empty(t, w)
+	q.Notify("foo", 10)
+	channelClosed(t, w)
+	w = q.Add(context.Background(), "foo", 5)
+	require.Empty(t, w)
+	q.Notify("foo", 4)
+	require.Empty(t, w)
+	q.Notify("foo", 20)
+	channelClosed(t, w)
+
+	h, _ := q.items.Load("foo")
+	require.Zero(t, h.Len())
+}
+
+func TestNotificationQueueTimeout(t *testing.T) {
+	q := NewNotificationQueue()
+	go q.Run()
+	defer q.Close()
+
+	ctx, cancel := context.WithCancel(context.TODO())
+	q.Add(ctx, "foo", 1)
+	q.Add(ctx, "foo", 2)
+	q.Add(ctx, "foo", 3)
+	w := q.Add(ctx, "foo", 4)
+	cancel()
+
+	require.Eventually(t, func() bool {
+		select {
+		case err := <-w:
+			return assert.ErrorIs(t, err, context.Canceled)
+		default:
+			return false
+		}
+	}, 5*time.Second, 100*time.Millisecond)
+
+	require.Eventually(t, func() bool {
+		h, _ := q.items.Load("foo")
+		return assert.Zero(t, h.Len())
+	}, 5*time.Second, 100*time.Millisecond)
+}
+
+func channelClosed(t *testing.T, w <-chan error) {
+	t.Helper()
+	require.Eventually(t, func() bool {
+		select {
+		case <-w:
+			return true
+		default:
+			return false
+		}
+	}, 10*time.Millisecond, time.Millisecond)
+}
diff --git a/storage/table/config.go b/storage/table/config.go
index a9a7ef2e..a066eca4 100644
--- a/storage/table/config.go
+++ b/storage/table/config.go
@@ -103,7 +103,9 @@ type TableConfig struct {
 	// TableCacheSize shared table cache size, the cache is used to hold handles to open SSTs.
 	TableCacheSize int
 	// RecoveryType the in-cluster snapshot recovery type.
 	RecoveryType SnapshotRecoveryType
+	// AppliedIndexListener is called with the table name and the latest applied (leader) index.
+	AppliedIndexListener func(table string, rev uint64)
 }
 
 type MetaConfig struct {
diff --git a/storage/table/fsm/fsm.go b/storage/table/fsm/fsm.go
index 17d88ba2..175932d3 100644
--- a/storage/table/fsm/fsm.go
+++ b/storage/table/fsm/fsm.go
@@ -86,10 +86,13 @@ func (s *snapshotHeader) snapshotType() SnapshotRecoveryType {
 	return SnapshotRecoveryType(s[6])
 }
 
-func New(tableName, stateMachineDir string, fs vfs.FS, blockCache *pebble.Cache, tableCache *pebble.TableCache, srt SnapshotRecoveryType) sm.CreateOnDiskStateMachineFunc {
+func New(tableName, stateMachineDir string, fs vfs.FS, blockCache *pebble.Cache, tableCache *pebble.TableCache, srt SnapshotRecoveryType, af func(applied uint64)) sm.CreateOnDiskStateMachineFunc {
 	if fs == nil {
 		fs = vfs.Default
 	}
+	if af == nil {
+		af = func(applied uint64) {}
+	}
 	return func(clusterID uint64, nodeID uint64) sm.IOnDiskStateMachine {
 		hostname, _ := os.Hostname()
 		dbDirName := rp.GetNodeDBDirName(stateMachineDir, hostname, fmt.Sprintf("%s-%d", tableName, clusterID))
@@ -105,6 +108,7 @@ func New(tableName, stateMachineDir string, fs vfs.FS, blockCache *pebble.Cache,
 			log:          zap.S().Named("table").Named(tableName),
 			metrics:      newMetrics(tableName, clusterID),
 			recoveryType: srt,
+			appliedFunc:  af,
 		}
 	}
 }
@@ -123,6 +127,7 @@ type FSM struct {
 	tableCache   *pebble.TableCache
 	metrics      *metrics
 	recoveryType SnapshotRecoveryType
+	appliedFunc  func(applied uint64)
 }
 
 func (p *FSM) Open(_ <-chan struct{}) (uint64, error) {
@@ -178,6 +183,11 @@ func (p *FSM) Open(_ <-chan struct{}) (uint64, error) {
 		return 0, err
 	}
 	p.metrics.applied.Store(idx)
+	p.appliedFunc(idx)
+	lx, _ := readLocalIndex(db, sysLeaderIndex)
+	if lx != 0 {
+		p.appliedFunc(lx)
+	}
 	return idx, nil
 }
 
@@ -301,6 +311,11 @@ func (p *FSM) Update(updates []sm.Entry) ([]sm.Entry, error) {
 	}
 
 	p.metrics.applied.Store(idx)
+	if ctx.leaderIndex != nil {
+		p.appliedFunc(*ctx.leaderIndex)
+	} else {
+		p.appliedFunc(idx)
+	}
 	return updates, nil
 }
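How the listener travels from config to FSM is easiest to see flattened into one place. A condensed model (glue code compressed for illustration, names from the diff):

```go
package main

import "fmt"

func main() {
	// storage.TableConfig.AppliedIndexListener: on the follower this is
	// nQueue.Notify, releasing waiters whose revision is <= applied.
	listener := func(table string, rev uint64) {
		fmt.Printf("release waiters on %q up to revision %d\n", table, rev)
	}

	// manager.startTable closes over the table name when building the FSM:
	af := func(applied uint64) { listener("mytable", applied) }

	// fsm.Update prefers the replicated leader index when present and falls
	// back to the local raft index; fsm.Open replays the persisted values.
	af(42)
}
```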
newMetrics("test", 1), + appliedFunc: func(applied uint64) {}, } db, err := rp.OpenDB(fsm.dirname, rp.WithFS(fsm.fs)) diff --git a/storage/table/fsm/fsm_test.go b/storage/table/fsm/fsm_test.go index 7d39688d..f14e2e06 100644 --- a/storage/table/fsm/fsm_test.go +++ b/storage/table/fsm/fsm_test.go @@ -80,12 +80,13 @@ func TestSM_Open(t *testing.T) { t.Run(tt.name, func(t *testing.T) { r := require.New(t) p := &FSM{ - fs: vfs.NewMem(), - clusterID: tt.fields.clusterID, - nodeID: tt.fields.nodeID, - dirname: tt.fields.dirname, - log: zap.NewNop().Sugar(), - metrics: newMetrics(testTable, tt.fields.clusterID), + fs: vfs.NewMem(), + clusterID: tt.fields.clusterID, + nodeID: tt.fields.nodeID, + dirname: tt.fields.dirname, + log: zap.NewNop().Sugar(), + metrics: newMetrics(testTable, tt.fields.clusterID), + appliedFunc: func(applied uint64) {}, } _, err := p.Open(nil) if tt.wantErr { @@ -103,12 +104,13 @@ func TestFSM_ReOpen(t *testing.T) { fs := vfs.NewMem() const testIndex uint64 = 10 p := &FSM{ - fs: fs, - clusterID: 1, - nodeID: 1, - dirname: "/tmp/dir", - log: zap.NewNop().Sugar(), - metrics: newMetrics(testTable, 1), + fs: fs, + clusterID: 1, + nodeID: 1, + dirname: "/tmp/dir", + log: zap.NewNop().Sugar(), + metrics: newMetrics(testTable, 1), + appliedFunc: func(applied uint64) {}, } t.Log("open FSM") @@ -734,12 +736,13 @@ func equalResult(t *testing.T, want sm.Result, got sm.Result) { func emptySM() *FSM { p := &FSM{ - fs: vfs.NewMem(), - clusterID: 1, - nodeID: 1, - dirname: "/tmp/tst", - log: zap.NewNop().Sugar(), - metrics: newMetrics(testTable, 1), + fs: vfs.NewMem(), + clusterID: 1, + nodeID: 1, + dirname: "/tmp/tst", + log: zap.NewNop().Sugar(), + metrics: newMetrics(testTable, 1), + appliedFunc: func(applied uint64) {}, } _, err := p.Open(nil) if err != nil { @@ -777,12 +780,13 @@ func filledSM() *FSM { }) } p := &FSM{ - fs: vfs.NewMem(), - clusterID: 1, - nodeID: 1, - dirname: "/tmp/tst", - log: zap.NewNop().Sugar(), - metrics: newMetrics(testTable, 1), + fs: vfs.NewMem(), + clusterID: 1, + nodeID: 1, + dirname: "/tmp/tst", + log: zap.NewNop().Sugar(), + metrics: newMetrics(testTable, 1), + appliedFunc: func(applied uint64) {}, } _, err := p.Open(nil) if err != nil { @@ -811,12 +815,13 @@ func filledLargeValuesSM() *FSM { } } p := &FSM{ - fs: vfs.NewMem(), - clusterID: 1, - nodeID: 1, - dirname: "/tmp/tst", - log: zap.NewNop().Sugar(), - metrics: newMetrics(testTable, 1), + fs: vfs.NewMem(), + clusterID: 1, + nodeID: 1, + dirname: "/tmp/tst", + log: zap.NewNop().Sugar(), + metrics: newMetrics(testTable, 1), + appliedFunc: func(applied uint64) {}, } _, err := p.Open(nil) if err != nil { @@ -831,12 +836,13 @@ func filledLargeValuesSM() *FSM { func filledIndexOnlySM() *FSM { p := &FSM{ - fs: vfs.NewMem(), - clusterID: 1, - nodeID: 1, - dirname: "/tmp/tst", - log: zap.NewNop().Sugar(), - metrics: newMetrics(testTable, 1), + fs: vfs.NewMem(), + clusterID: 1, + nodeID: 1, + dirname: "/tmp/tst", + log: zap.NewNop().Sugar(), + metrics: newMetrics(testTable, 1), + appliedFunc: func(applied uint64) {}, } _, err := p.Open(nil) if err != nil { diff --git a/storage/table/fsm/metrics_test.go b/storage/table/fsm/metrics_test.go index 437c3d10..92b107fe 100644 --- a/storage/table/fsm/metrics_test.go +++ b/storage/table/fsm/metrics_test.go @@ -15,12 +15,13 @@ import ( func TestFSM_Metrics(t *testing.T) { p := &FSM{ - fs: vfs.NewMem(), - clusterID: 1, - nodeID: 1, - dirname: "/tmp", - log: zap.NewNop().Sugar(), - metrics: newMetrics(testTable, 1), + fs: vfs.NewMem(), + clusterID: 1, + nodeID: 1, + 
diff --git a/storage/table/fsm/metrics_test.go b/storage/table/fsm/metrics_test.go
index 437c3d10..92b107fe 100644
--- a/storage/table/fsm/metrics_test.go
+++ b/storage/table/fsm/metrics_test.go
@@ -15,12 +15,13 @@ import (
 
 func TestFSM_Metrics(t *testing.T) {
 	p := &FSM{
-		fs:        vfs.NewMem(),
-		clusterID: 1,
-		nodeID:    1,
-		dirname:   "/tmp",
-		log:       zap.NewNop().Sugar(),
-		metrics:   newMetrics(testTable, 1),
+		fs:          vfs.NewMem(),
+		clusterID:   1,
+		nodeID:      1,
+		dirname:     "/tmp",
+		log:         zap.NewNop().Sugar(),
+		metrics:     newMetrics(testTable, 1),
+		appliedFunc: func(applied uint64) {},
 	}
 	_, _ = p.Open(nil)
 	inFile, err := os.Open(path.Join("testdata", "metrics"))
diff --git a/storage/table/manager.go b/storage/table/manager.go
index f8d13997..badedf62 100644
--- a/storage/table/manager.go
+++ b/storage/table/manager.go
@@ -495,14 +495,22 @@ func (m *Manager) startTable(name string, id uint64) error {
 		return m.nh.StartOnDiskReplica(
 			map[uint64]dragonboat.Target{},
 			false,
-			fsm.New(name, m.cfg.Table.DataDir, m.cfg.Table.FS, m.blockCache, m.tableCache, fsm.SnapshotRecoveryType(m.cfg.Table.RecoveryType)),
+			fsm.New(name, m.cfg.Table.DataDir, m.cfg.Table.FS, m.blockCache, m.tableCache, fsm.SnapshotRecoveryType(m.cfg.Table.RecoveryType), func(applied uint64) {
+				if m.cfg.Table.AppliedIndexListener != nil {
+					m.cfg.Table.AppliedIndexListener(name, applied)
+				}
+			}),
 			tableRaftConfig(m.cfg.NodeID, id, m.cfg.Table),
 		)
 	}
 	return m.nh.StartOnDiskReplica(
 		m.members,
 		false,
-		fsm.New(name, m.cfg.Table.DataDir, m.cfg.Table.FS, m.blockCache, m.tableCache, fsm.SnapshotRecoveryType(m.cfg.Table.RecoveryType)),
+		fsm.New(name, m.cfg.Table.DataDir, m.cfg.Table.FS, m.blockCache, m.tableCache, fsm.SnapshotRecoveryType(m.cfg.Table.RecoveryType), func(applied uint64) {
+			if m.cfg.Table.AppliedIndexListener != nil {
+				m.cfg.Table.AppliedIndexListener(name, applied)
+			}
+		}),
 		tableRaftConfig(m.cfg.NodeID, id, m.cfg.Table),
 	)
 }
diff --git a/util/heap/heap.go b/util/heap/heap.go
new file mode 100644
index 00000000..85cbf122
--- /dev/null
+++ b/util/heap/heap.go
@@ -0,0 +1,116 @@
+// Copyright JAMF Software, LLC
+
+package heap
+
+// Heap is a generic min-heap ordered by Less; Slice is exported so callers can mutate elements in place and then call Fix.
+type Heap[E any] struct {
+	Slice []E
+	Less  func(E, E) bool
+}
+
+// New builds a heap from items in O(n).
+func New[E any](less func(E, E) bool, items ...E) *Heap[E] {
+	h := &Heap[E]{
+		Slice: items,
+		Less:  less,
+	}
+	n := len(items)
+	for i := n/2 - 1; i >= 0; i-- {
+		h.down(i, n)
+	}
+	return h
+}
+
+// Peek returns the minimum element without removing it; it panics on an empty heap.
+func (h *Heap[E]) Peek() E {
+	if len(h.Slice) == 0 {
+		panic("empty slice")
+	}
+	return h.Slice[0]
+}
+
+// Push adds item and restores the heap invariant.
+func (h *Heap[E]) Push(item E) {
+	h.Slice = append(h.Slice, item)
+	h.up(len(h.Slice) - 1)
+}
+
+// Pop removes and returns the minimum element; it panics on an empty heap.
+func (h *Heap[E]) Pop() E {
+	if len(h.Slice) == 0 {
+		panic("empty slice")
+	}
+	n := len(h.Slice) - 1
+	h.swap(0, n)
+	h.down(0, n)
+	return h.zpop()
+}
+
+// Remove deletes and returns the element at index i.
+func (h *Heap[E]) Remove(i int) E {
+	n := len(h.Slice) - 1
+	if n != i {
+		h.swap(i, n)
+		if !h.down(i, n) {
+			h.up(i)
+		}
+	}
+	return h.zpop()
+}
+
+func (h *Heap[E]) Len() int {
+	return len(h.Slice)
+}
+
+// Fix re-establishes the invariant after the element at index i changed.
+func (h *Heap[E]) Fix(i int) {
+	if i < 0 {
+		return
+	}
+	if !h.down(i, len(h.Slice)) {
+		h.up(i)
+	}
+}
+
+func (h *Heap[E]) up(j int) {
+	for {
+		i := (j - 1) / 2 // parent
+		if i == j || !h.Less(h.Slice[j], h.Slice[i]) {
+			break
+		}
+		h.swap(i, j)
+		j = i
+	}
+}
+
+func (h *Heap[E]) down(i0 int, n int) bool {
+	i := i0
+	for {
+		j1 := 2*i + 1
+		if j1 >= n || j1 < 0 { // j1 < 0 after int overflow
+			break
+		}
+		j := j1 // left child
+		if j2 := j1 + 1; j2 < n && h.Less(h.Slice[j2], h.Slice[j1]) {
+			j = j2 // = 2*i + 2  // right child
+		}
+		if !h.Less(h.Slice[j], h.Slice[i]) {
+			break
+		}
+		h.swap(i, j)
+		i = j
+	}
+	return i > i0
+}
+
+func (h *Heap[E]) swap(i, j int) {
+	h.Slice[i], h.Slice[j] = h.Slice[j], h.Slice[i]
+}
+
+func (h *Heap[E]) zpop() E {
+	var zero E
+	e0 := h.Slice[len(h.Slice)-1]
+	h.Slice[len(h.Slice)-1] = zero
+	h.Slice = h.Slice[:len(h.Slice)-1]
+	return e0
+}
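The heap is a generic port of container/heap with an exported Slice. Quick usage (values verified by hand against the implementation above):

```go
package main

import (
	"fmt"

	"github.com/jamf/regatta/util/heap"
)

func main() {
	// A min-heap ordered by the supplied less function; Peek is the minimum.
	h := heap.New(func(a, b int) bool { return a < b }, 5, 1, 4)
	h.Push(3)
	fmt.Println(h.Peek()) // 1
	fmt.Println(h.Pop())  // 1
	fmt.Println(h.Pop())  // 3

	// After mutating an element in place, Fix(i) restores the invariant;
	// the notification queue relies on this when it zeroes revisions.
	h.Slice[0] = 42
	h.Fix(0)
	fmt.Println(h.Pop()) // 5
}
```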
diff --git a/util/heap/heap_test.go b/util/heap/heap_test.go
new file mode 100644
index 00000000..47e63a6b
--- /dev/null
+++ b/util/heap/heap_test.go
@@ -0,0 +1,195 @@
+// Copyright JAMF Software, LLC
+
+package heap
+
+import (
+	"math/rand"
+	"testing"
+)
+
+func newHeap() *Heap[int] {
+	return &Heap[int]{
+		Less: func(i, j int) bool {
+			return i < j
+		},
+	}
+}
+
+func verify[T any](t *testing.T, h *Heap[T], i int) {
+	t.Helper()
+	n := h.Len()
+	j1 := 2*i + 1
+	j2 := 2*i + 2
+	if j1 < n {
+		if h.Less(h.Slice[j1], h.Slice[i]) {
+			t.Errorf("heap invariant invalidated [%d] = %v > [%d] = %v", i, h.Slice[i], j1, h.Slice[j1])
+			return
+		}
+		verify(t, h, j1)
+	}
+	if j2 < n {
+		if h.Less(h.Slice[j2], h.Slice[i]) {
+			t.Errorf("heap invariant invalidated [%d] = %v > [%d] = %v", i, h.Slice[i], j2, h.Slice[j2])
+			return
+		}
+		verify(t, h, j2)
+	}
+}
+
+func TestInit0(t *testing.T) {
+	h := newHeap()
+	for i := 20; i > 0; i-- {
+		h.Push(0) // all elements are the same
+	}
+	verify(t, h, 0)
+
+	for i := 1; h.Len() > 0; i++ {
+		x := h.Pop()
+		verify(t, h, 0)
+		if x != 0 {
+			t.Errorf("%d.th pop got %d; want %d", i, x, 0)
+		}
+	}
+}
+
+func TestInit1(t *testing.T) {
+	h := newHeap()
+	for i := 20; i > 0; i-- {
+		h.Push(i) // all elements are different
+	}
+	verify(t, h, 0)
+
+	for i := 1; h.Len() > 0; i++ {
+		x := h.Pop()
+		verify(t, h, 0)
+		if x != i {
+			t.Errorf("%d.th pop got %d; want %d", i, x, i)
+		}
+	}
+}
+
+func Test(t *testing.T) {
+	h := newHeap()
+	verify(t, h, 0)
+
+	for i := 20; i > 10; i-- {
+		h.Push(i)
+	}
+	verify(t, h, 0)
+
+	for i := 10; i > 0; i-- {
+		h.Push(i)
+		verify(t, h, 0)
+	}
+
+	for i := 1; h.Len() > 0; i++ {
+		x := h.Pop()
+		if i < 20 {
+			h.Push(20 + i)
+		}
+		verify(t, h, 0)
+		if x != i {
+			t.Errorf("%d.th pop got %d; want %d", i, x, i)
+		}
+	}
+}
+
+func TestRemove0(t *testing.T) {
+	h := newHeap()
+	for i := 0; i < 10; i++ {
+		h.Push(i)
+	}
+	verify(t, h, 0)
+
+	for h.Len() > 0 {
+		i := h.Len() - 1
+		x := h.Remove(i)
+		if x != i {
+			t.Errorf("Remove(%d) got %d; want %d", i, x, i)
+		}
+		verify(t, h, 0)
+	}
+}
+
+func TestRemove1(t *testing.T) {
+	h := newHeap()
+	for i := 0; i < 10; i++ {
+		h.Push(i)
+	}
+	verify(t, h, 0)
+
+	for i := 0; h.Len() > 0; i++ {
+		x := h.Remove(0)
+		if x != i {
+			t.Errorf("Remove(0) got %d; want %d", x, i)
+		}
+		verify(t, h, 0)
+	}
+}
+
+func TestRemove2(t *testing.T) {
+	N := 10
+
+	h := newHeap()
+	for i := 0; i < N; i++ {
+		h.Push(i)
+	}
+	verify(t, h, 0)
+
+	m := make(map[int]bool)
+	for h.Len() > 0 {
+		m[h.Remove((h.Len()-1)/2)] = true
+		verify(t, h, 0)
+	}
+
+	if len(m) != N {
+		t.Errorf("len(m) = %d; want %d", len(m), N)
+	}
+	for i := 0; i < len(m); i++ {
+		if !m[i] {
+			t.Errorf("m[%d] doesn't exist", i)
+		}
+	}
+}
+
+func BenchmarkDup(b *testing.B) {
+	const n = 10000
+	h := newHeap()
+	h.Slice = make([]int, 0, n)
+	for i := 0; i < b.N; i++ {
+		for j := 0; j < n; j++ {
+			h.Push(0) // all elements are the same
+		}
+		for h.Len() > 0 {
+			h.Pop()
+		}
+	}
+}
+
+func TestFix(t *testing.T) {
+	h := newHeap()
+	verify(t, h, 0)
+
+	for i := 200; i > 0; i -= 10 {
+		h.Push(i)
+	}
+	verify(t, h, 0)
+
+	if h.Slice[0] != 10 {
+		t.Fatalf("Expected head to be 10, was %d", h.Slice[0])
+	}
+	h.Slice[0] = 210
+	h.Fix(0)
+	verify(t, h, 0)
+
+	for i := 100; i > 0; i-- {
+		elem := rand.Intn(h.Len())
+		if i&1 == 0 {
+			h.Slice[elem] *= 2
+		} else {
+			h.Slice[elem] /= 2
+		}
+		h.Fix(elem)
+		verify(t, h, 0)
+	}
+}
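Contains builds on the push-style Seq iterator: returning false from the yield callback stops the sequence, so the scan short-circuits on the first match. A tiny demo (assuming Seq is the func(yield func(T) bool) type used throughout util/iter):

```go
package main

import (
	"fmt"

	"github.com/jamf/regatta/util/iter"
)

func main() {
	seq := iter.Seq[int](func(yield func(int) bool) {
		for _, v := range []int{1, 2, 3} {
			if !yield(v) {
				return // consumer asked to stop early
			}
		}
	})
	fmt.Println(iter.Contains(seq, 2)) // true, stops after yielding 2
	fmt.Println(iter.Contains(seq, 9)) // false, scans the whole sequence
}
```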
diff --git a/util/iter/iter.go b/util/iter/iter.go
index 0240256c..b081a4a7 100644
--- a/util/iter/iter.go
+++ b/util/iter/iter.go
@@ -82,3 +82,15 @@ func Pull[T any](seq Seq[T]) (iter func() (T, bool), stop func()) {
 		<-yield
 	})
 }
+
+func Contains[T comparable](seq Seq[T], item T) bool {
+	found := false
+	seq(func(t T) bool {
+		if t == item {
+			found = true
+			return false
+		}
+		return true
+	})
+	return found
+}
diff --git a/util/sync.go b/util/sync.go
index 91f7d817..7e301d1f 100644
--- a/util/sync.go
+++ b/util/sync.go
@@ -4,38 +4,75 @@ package util
 
 import (
 	"sync"
+
+	"github.com/jamf/regatta/util/iter"
 )
 
 type SyncMap[K comparable, V any] struct {
-	m   map[K]V
-	mtx sync.RWMutex
+	m           map[K]V
+	mtx         sync.RWMutex
+	defaultFunc func(K) V
+}
+
+// NewSyncMap creates a SyncMap that lazily materializes missing keys via defaulter (pass nil for plain map semantics).
+func NewSyncMap[K comparable, V any](defaulter func(K) V) *SyncMap[K, V] {
+	return &SyncMap[K, V]{m: make(map[K]V), defaultFunc: defaulter}
+}
+
+// Keys returns a read-locked iterator over the map's keys.
+func (s *SyncMap[K, V]) Keys() iter.Seq[K] {
+	return func(yield func(K) bool) {
+		s.mtx.RLock()
+		defer s.mtx.RUnlock()
+		for k := range s.m {
+			if !yield(k) {
+				break
+			}
+		}
+	}
+}
+
+// Values returns a read-locked iterator over the map's values.
+func (s *SyncMap[K, V]) Values() iter.Seq[V] {
+	return func(yield func(V) bool) {
+		s.mtx.RLock()
+		defer s.mtx.RUnlock()
+		for _, v := range s.m {
+			if !yield(v) {
+				break
+			}
+		}
+	}
+}
 
 func (s *SyncMap[K, V]) Load(key K) (V, bool) {
 	s.mtx.RLock()
-	defer s.mtx.RUnlock()
-	if s.m == nil {
-		return *new(V), false
-	}
 	v, ok := s.m[key]
+	if !ok && s.defaultFunc != nil {
+		s.mtx.RUnlock()
+		s.mtx.Lock()
+		defer s.mtx.Unlock()
+		// Re-check under the write lock: another goroutine may have stored a
+		// value for key in the meantime; overwriting it would lose its state.
+		if v, ok = s.m[key]; ok {
+			return v, true
+		}
+		s.m[key] = s.defaultFunc(key)
+		return s.m[key], true
+	}
+	s.mtx.RUnlock()
 	return v, ok
 }
 
 func (s *SyncMap[K, V]) Store(key K, val V) {
 	s.mtx.Lock()
 	defer s.mtx.Unlock()
-	if s.m == nil {
-		s.m = make(map[K]V)
-	}
 	s.m[key] = val
 }
 
 func (s *SyncMap[K, V]) ComputeIfAbsent(key K, valFunc func(K) V) V {
 	s.mtx.Lock()
 	defer s.mtx.Unlock()
-	if s.m == nil {
-		s.m = make(map[K]V)
-	}
 	v, ok := s.m[key]
 	if !ok {
 		v = valFunc(key)
diff --git a/util/sync_test.go b/util/sync_test.go
index 56cd0713..a1e3475a 100644
--- a/util/sync_test.go
+++ b/util/sync_test.go
@@ -23,7 +23,8 @@ func TestSyncMap_ComputeIfAbsent(t *testing.T) {
 		want string
 	}{
 		{
-			name: "compute missing key",
+			name:   "compute missing key",
+			fields: fields{m: map[string]string{}},
 			args: args{
 				key: "key",
 				valFunc: func(s string) string {
@@ -151,7 +152,8 @@ func TestSyncMap_Store(t *testing.T) {
 		assert func(*testing.T, *SyncMap[string, string])
 	}{
 		{
-			name: "store into empty map",
+			name:   "store into empty map",
+			fields: fields{m: map[string]string{}},
 			args: args{
 				key: "key",
 				val: "value",