diff --git a/pkg/p2p/streamtest/streamtest.go b/pkg/p2p/streamtest/streamtest.go
index 825c323a759..942e99aef57 100644
--- a/pkg/p2p/streamtest/streamtest.go
+++ b/pkg/p2p/streamtest/streamtest.go
@@ -101,6 +101,7 @@ func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Head
 			return nil, err
 		}
 	}
+
 	recordIn := newRecord()
 	recordOut := newRecord()
 	streamOut := newStream(recordIn, recordOut)
diff --git a/pkg/pushsync/pushsync.go b/pkg/pushsync/pushsync.go
index 7fd0c434752..4be0f8bff2b 100644
--- a/pkg/pushsync/pushsync.go
+++ b/pkg/pushsync/pushsync.go
@@ -445,15 +445,6 @@ func (ps *PushSync) pushPeer(ctx context.Context, peer swarm.Address, ch swarm.C
 
 	ps.metrics.TotalSent.Inc()
 
-	// if you manage to get a tag, just increment the respective counter
-	t, err := ps.tagger.Get(ch.TagID())
-	if err == nil && t != nil {
-		err = t.Inc(tags.StateSent)
-		if err != nil {
-			return nil, true, fmt.Errorf("tag %d increment: %v", ch.TagID(), err)
-		}
-	}
-
 	var receipt pb.Receipt
 	if err := r.ReadMsgWithContext(ctx, &receipt); err != nil {
 		_ = streamer.Reset()
@@ -470,6 +461,15 @@ func (ps *PushSync) pushPeer(ctx context.Context, peer swarm.Address, ch swarm.C
 		return nil, true, err
 	}
 
+	// if you manage to get a tag, just increment the respective counter
+	t, err := ps.tagger.Get(ch.TagID())
+	if err == nil && t != nil {
+		err = t.Inc(tags.StateSent)
+		if err != nil {
+			return nil, true, fmt.Errorf("tag %d increment: %v", ch.TagID(), err)
+		}
+	}
+
 	return &receipt, true, nil
 }
 
diff --git a/pkg/pushsync/pushsync_test.go b/pkg/pushsync/pushsync_test.go
index 0e7d696ba8f..14fbe035d2d 100644
--- a/pkg/pushsync/pushsync_test.go
+++ b/pkg/pushsync/pushsync_test.go
@@ -488,7 +488,7 @@ func TestPushChunkToNextClosest(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	if ta2.Get(tags.StateSent) != 2 {
+	if ta2.Get(tags.StateSent) != 1 {
 		t.Fatalf("tags error")
 	}
 
@@ -940,10 +940,8 @@ func TestPushChunkToClosestSkipFailed(t *testing.T) {
 	)
 	defer storerPeer4.Close()
 
-	var (
-		fail = true
-		lock sync.Mutex
-	)
+	triggerCount := 0
+	var lock sync.Mutex
 
 	recorder := streamtest.New(
 		streamtest.WithPeerProtocols(
			map[string]p2p.ProtocolSpec{
				peer1.String(): psPeer1.Protocol(),
				peer2.String(): psPeer2.Protocol(),
				peer3.String(): psPeer3.Protocol(),
				peer4.String(): psPeer4.Protocol(),
			},
		),
@@ -954,15 +952,25 @@ func TestPushChunkToClosestSkipFailed(t *testing.T) {
 				peer4.String(): psPeer4.Protocol(),
 			},
 		),
-		streamtest.WithStreamError(
-			func(addr swarm.Address, _, _, _ string) error {
-				lock.Lock()
-				defer lock.Unlock()
-				if fail && addr.String() != peer4.String() {
-					return errors.New("peer not reachable")
-				}
+		streamtest.WithMiddlewares(
+			func(h p2p.HandlerFunc) p2p.HandlerFunc {
+				return func(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
+					lock.Lock()
+					defer lock.Unlock()
 
-				return nil
+					if triggerCount < 9 {
+						triggerCount++
+						stream.Close()
+						return errors.New("fmt")
+					}
+
+					if err := h(ctx, peer, stream); err != nil {
+						return err
+					}
+					// close stream after all previous middlewares wrote to it
+					// so that the receiving peer can get all the post messages
+					return stream.Close()
+				}
 			},
 		),
 		streamtest.WithBaseAddr(pivotNode),