diff --git a/p2p/exchange_test.go b/p2p/exchange_test.go
index f126e076..b0fb299c 100644
--- a/p2p/exchange_test.go
+++ b/p2p/exchange_test.go
@@ -18,11 +18,10 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
-	"github.com/celestiaorg/go-libp2p-messenger/serde"
-
 	"github.com/celestiaorg/go-header"
 	"github.com/celestiaorg/go-header/headertest"
 	p2p_pb "github.com/celestiaorg/go-header/p2p/pb"
+	"github.com/celestiaorg/go-libp2p-messenger/serde"
 )
 
 const networkID = "private"
diff --git a/sync/sync.go b/sync/sync.go
index 31aaf226..ce66b6a9 100644
--- a/sync/sync.go
+++ b/sync/sync.go
@@ -44,6 +44,8 @@ type Syncer[H header.Header] struct {
 	triggerSync chan struct{}
 	// pending keeps ranges of valid new network headers awaiting to be appended to store
 	pending ranges[H]
+	// incomingMu ensures only one incoming network head candidate is processed at a time
+	incomingMu sync.Mutex
 
 	// controls lifecycle for syncLoop
 	ctx context.Context
diff --git a/sync/sync_getter.go b/sync/sync_getter.go
index 0f32dc6e..5f24de44 100644
--- a/sync/sync_getter.go
+++ b/sync/sync_getter.go
@@ -3,29 +3,50 @@ package sync
 import (
 	"context"
 	"sync"
+	"sync/atomic"
 
 	"github.com/celestiaorg/go-header"
 )
 
 // syncGetter is a Getter wrapper that ensure only one Head call happens at the time
 type syncGetter[H header.Header] struct {
+	getterLk   sync.RWMutex
+	isGetterLk atomic.Bool
 	header.Getter[H]
-
-	headLk  sync.RWMutex
-	headErr error
-	head    H
 }
 
-func (se *syncGetter[H]) Head(ctx context.Context) (H, error) {
-	// the lock construction here ensures only one routine calling Head at a time
+// Lock locks the getter for a single user.
+// Reports 'true' if the lock was acquired by the calling routine.
+// Does not require unlocking on 'false'.
+func (sg *syncGetter[H]) Lock() bool {
+	// the lock construction here ensures only one routine holds the lock at a time
 	// while others wait via Rlock
-	if !se.headLk.TryLock() {
-		se.headLk.RLock()
-		defer se.headLk.RUnlock()
-		return se.head, se.headErr
+	locked := sg.getterLk.TryLock()
+	if !locked {
+		sg.getterLk.RLock()
+		defer sg.getterLk.RUnlock()
+		return false
 	}
-	defer se.headLk.Unlock()
+	sg.isGetterLk.Store(locked)
+	return locked
+}
+
+// Unlock unlocks the getter.
+func (sg *syncGetter[H]) Unlock() {
+	sg.checkLock("Unlock without preceding Lock on syncGetter")
+	sg.getterLk.Unlock()
+	sg.isGetterLk.Store(false)
+}
 
-	se.head, se.headErr = se.Getter.Head(ctx)
-	return se.head, se.headErr
+// Head must be called with the Lock held.
+func (sg *syncGetter[H]) Head(ctx context.Context) (H, error) {
+	sg.checkLock("Head without preceding Lock on syncGetter")
+	return sg.Getter.Head(ctx)
+}
+
+// checkLock ensures API safety by panicking if the lock is not held.
+func (sg *syncGetter[H]) checkLock(msg string) {
+	if !sg.isGetterLk.Load() {
+		panic(msg)
+	}
+}
diff --git a/sync/sync_getter_test.go b/sync/sync_getter_test.go
index 8070512c..06fcd7ad 100644
--- a/sync/sync_getter_test.go
+++ b/sync/sync_getter_test.go
@@ -26,6 +26,10 @@ func TestSyncGetterHead(t *testing.T) {
 	wg.Add(1)
 	go func() {
 		defer wg.Done()
+		if !sex.Lock() {
+			return
+		}
+		defer sex.Unlock()
 		h, err := sex.Head(ctx)
 		if h != nil || err != errFakeHead {
 			t.Fail()
diff --git a/sync/sync_head.go b/sync/sync_head.go
index d562c46a..bcc5749e 100644
--- a/sync/sync_head.go
+++ b/sync/sync_head.go
@@ -31,6 +31,15 @@ func (s *Syncer[H]) Head(ctx context.Context) (H, error) {
 	// * If now >= TNH && now <= TNH + (THP) header propagation time
 	// * Wait for header to arrive instead of requesting it
 	// * This way we don't request as we know the new network header arrives exactly
+	//
+	// single-flight protection
+	// ensure only one Head is requested at a time
+	if !s.getter.Lock() {
+		// another routine held the lock and set the subjective head for us,
+		// so just recursively get it
+		return s.Head(ctx)
+	}
+	defer s.getter.Unlock()
 	netHead, err := s.getter.Head(ctx)
 	if err != nil {
 		return netHead, err
@@ -68,6 +77,14 @@ func (s *Syncer[H]) subjectiveHead(ctx context.Context) (H, error) {
 	}
 	// otherwise, request head from a trusted peer
 	log.Infow("stored head header expired", "height", storeHead.Height())
+	// single-flight protection
+	// ensure only one Head is requested at a time
+	if !s.getter.Lock() {
+		// another routine held the lock and set the subjective head for us,
+		// so just recursively get it
+		return s.subjectiveHead(ctx)
+	}
+	defer s.getter.Unlock()
 	trustHead, err := s.getter.Head(ctx)
 	if err != nil {
 		return trustHead, err
@@ -119,6 +136,9 @@ func (s *Syncer[H]) setSubjectiveHead(ctx context.Context, netHead H) {
 // incomingNetworkHead processes new potential network headers.
 // If the header valid, sets as new subjective header.
 func (s *Syncer[H]) incomingNetworkHead(ctx context.Context, netHead H) pubsub.ValidationResult {
+	// ensure there is no racing between network head candidates
+	s.incomingMu.Lock()
+	defer s.incomingMu.Unlock()
 	// first of all, check the validity of the netHead
 	res := s.validateHead(ctx, netHead)
 	if res == pubsub.ValidationAccept {
diff --git a/sync/sync_head_test.go b/sync/sync_head_test.go
new file mode 100644
index 00000000..6d464ca4
--- /dev/null
+++ b/sync/sync_head_test.go
@@ -0,0 +1,46 @@
+package sync
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	pubsub "github.com/libp2p/go-libp2p-pubsub"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/celestiaorg/go-header/headertest"
+)
+
+func TestSyncer_incomingNetworkHeadRaces(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	t.Cleanup(cancel)
+
+	suite := headertest.NewTestSuite(t)
+	store := headertest.NewStore[*headertest.DummyHeader](t, suite, 1)
+	syncer, err := NewSyncer[*headertest.DummyHeader](
+		store,
+		store,
+		headertest.NewDummySubscriber(),
+	)
+	require.NoError(t, err)
+
+	incoming := suite.NextHeader()
+
+	var hits atomic.Uint32
+	var wg sync.WaitGroup
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if syncer.incomingNetworkHead(ctx, incoming) == pubsub.ValidationAccept {
+				hits.Add(1)
+			}
+		}()
+	}
+
+	wg.Wait()
+	assert.EqualValues(t, 1, hits.Load())
+}
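Reviewer note, not part of the patch: the standalone sketch below distills the single-flight pattern that syncGetter implements above with TryLock/RLock. Exactly one goroutine in an overlapping burst performs the expensive call; the rest block on RLock until the winner unlocks and then return false without doing the work, which is why Syncer.Head and subjectiveHead can simply retry and pick up the head the winner stored. The singleFlight type and its do method are hypothetical names used only for illustration.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// singleFlight mirrors syncGetter's locking scheme: the TryLock winner does
// the expensive work; concurrent losers block on RLock until it finishes.
type singleFlight struct {
	mu    sync.RWMutex
	calls atomic.Uint32 // how many goroutines actually performed the work
}

// do runs fn only if this goroutine wins the TryLock race, reporting whether
// it ran. Losers wait via RLock for the winner to unlock, then return false,
// signalling the caller to take the cheap path instead (as Syncer.Head does
// by re-reading the subjective head the winner just set).
func (s *singleFlight) do(fn func()) bool {
	if !s.mu.TryLock() {
		s.mu.RLock()
		defer s.mu.RUnlock()
		return false
	}
	defer s.mu.Unlock()
	s.calls.Add(1)
	fn()
	return true
}

func main() {
	var (
		sf singleFlight
		wg sync.WaitGroup
	)
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// stand-in for a network Head request
			sf.do(func() { time.Sleep(50 * time.Millisecond) })
		}()
	}
	wg.Wait()
	// With overlapping goroutines, typically exactly one performs the work.
	fmt.Println("expensive calls:", sf.calls.Load())
}

Note that this is deliberately weaker than a sync.Once: a goroutine arriving after the winner has unlocked will run the call again. That is the behavior the Syncer wants, since an expired subjective head must eventually be re-requested from the network.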