From bb096962ca5a2d4abe328a1e8246d558397ba598 Mon Sep 17 00:00:00 2001 From: Rod Vagg Date: Fri, 12 May 2023 08:41:54 +1000 Subject: [PATCH] fix(http): refactor MockRoundTripper (#229) --- pkg/internal/testutil/mockroundtripper.go | 299 +++++++++++++++ pkg/retriever/httpretriever_test.go | 425 +++------------------- 2 files changed, 353 insertions(+), 371 deletions(-) create mode 100644 pkg/internal/testutil/mockroundtripper.go diff --git a/pkg/internal/testutil/mockroundtripper.go b/pkg/internal/testutil/mockroundtripper.go new file mode 100644 index 00000000..e56772c3 --- /dev/null +++ b/pkg/internal/testutil/mockroundtripper.go @@ -0,0 +1,299 @@ +package testutil + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + "sync" + "testing" + "time" + + "github.com/benbjohnson/clock" + "github.com/filecoin-project/lassie/pkg/types" + "github.com/ipfs/go-cid" + "github.com/ipld/go-car/v2" + "github.com/ipld/go-car/v2/storage" + dagpb "github.com/ipld/go-codec-dagpb" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/linking" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/ipld/go-ipld-prime/traversal" + "github.com/ipld/go-ipld-prime/traversal/selector" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/require" +) + +type MockRoundTripRemote struct { + Peer peer.AddrInfo + LinkSystem linking.LinkSystem + Selector ipld.Node + RespondAt time.Time + Malformed bool +} + +type MockRoundTripper struct { + t *testing.T + ctx context.Context + clock *clock.Mock + remoteBlockDuration time.Duration + expectedPath map[cid.Cid]string + expectedScope map[cid.Cid]types.CarScope + remotes map[cid.Cid][]MockRoundTripRemote + startsCh chan peer.ID + statsCh chan RemoteStats + endsCh chan peer.ID +} + +var _ http.RoundTripper = (*MockRoundTripper)(nil) +var _ VerifierClient = (*MockRoundTripper)(nil) + +func NewMockRoundTripper( + t *testing.T, + ctx context.Context, + clock *clock.Mock, + remoteBlockDuration time.Duration, + expectedPath map[cid.Cid]string, + expectedScope map[cid.Cid]types.CarScope, + remotes map[cid.Cid][]MockRoundTripRemote, +) *MockRoundTripper { + return &MockRoundTripper{ + t, + ctx, + clock, + remoteBlockDuration, + expectedPath, + expectedScope, + remotes, + make(chan peer.ID, 32), + make(chan RemoteStats, 32), + make(chan peer.ID, 32), + } +} + +func (mrt *MockRoundTripper) getRemote(cid cid.Cid, maddr string) MockRoundTripRemote { + remotes, ok := mrt.remotes[cid] + require.True(mrt.t, ok) + for _, remote := range remotes { + if remote.Peer.Addrs[0].String() == maddr { + return remote + } + } + mrt.t.Fatal("remote not found") + return MockRoundTripRemote{} +} + +func (mrt *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + us := strings.Split(req.URL.Path, "/") + require.True(mrt.t, len(us) > 2) + require.Equal(mrt.t, us[1], "ipfs") + root, err := cid.Parse(us[2]) + require.NoError(mrt.t, err) + path := strings.Join(us[3:], "/") + expectedPath, ok := mrt.expectedPath[root] + if !ok { + require.Equal(mrt.t, path, "") + } else { + require.Equal(mrt.t, path, expectedPath) + } + expectedScope := types.CarScopeAll + if scope, ok := mrt.expectedScope[root]; ok { + expectedScope = scope + } + require.Equal(mrt.t, req.URL.RawQuery, fmt.Sprintf("car-scope=%s", expectedScope)) + ip := req.URL.Hostname() + port := req.URL.Port() + maddr := fmt.Sprintf("/ip4/%s/tcp/%s/http", ip, port) + remote := mrt.getRemote(root, maddr) + mrt.startsCh <- remote.Peer.ID + + sleepFor := mrt.clock.Until(remote.RespondAt) + if sleepFor > 0 { + select { + case <-mrt.ctx.Done(): + return nil, mrt.ctx.Err() + case <-mrt.clock.After(sleepFor): + } + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: newDeferredBody(mrt, remote, root), + }, nil +} + +func (mrt *MockRoundTripper) VerifyConnectionsReceived(ctx context.Context, t *testing.T, afterStart time.Duration, expectedConnections []peer.ID) { + if len(expectedConnections) > 0 { + require.FailNowf(t, "unexpected ConnectionsReceived", "@ %s", afterStart) + } +} + +func (mrt *MockRoundTripper) VerifyRetrievalsReceived(ctx context.Context, t *testing.T, afterStart time.Duration, expectedRetrievals []peer.ID) { + retrievals := make([]peer.ID, 0, len(expectedRetrievals)) + for i := 0; i < len(expectedRetrievals); i++ { + select { + case retrieval := <-mrt.startsCh: + retrievals = append(retrievals, retrieval) + case <-ctx.Done(): + require.FailNowf(t, "failed to receive expected retrievals", "expected %d, received %d @ %s", len(expectedRetrievals), i, afterStart) + } + } + require.ElementsMatch(t, expectedRetrievals, retrievals) +} + +func (mrt *MockRoundTripper) VerifyRetrievalsServed(ctx context.Context, t *testing.T, afterStart time.Duration, expectedServed []RemoteStats) { + remoteStats := make([]RemoteStats, 0, len(expectedServed)) + for i := 0; i < len(expectedServed); i++ { + select { + case stats := <-mrt.statsCh: + remoteStats = append(remoteStats, stats) + case <-ctx.Done(): + require.FailNowf(t, "failed to receive expected served", "expected %d, received %d @ %s", len(expectedServed), i, afterStart) + } + } + require.ElementsMatch(t, expectedServed, remoteStats) +} + +func (mrt *MockRoundTripper) VerifyRetrievalsCompleted(ctx context.Context, t *testing.T, afterStart time.Duration, expectedRetrievals []peer.ID) { + retrievals := make([]peer.ID, 0, len(expectedRetrievals)) + for i := 0; i < len(expectedRetrievals); i++ { + select { + case retrieval := <-mrt.endsCh: + retrievals = append(retrievals, retrieval) + case <-ctx.Done(): + require.FailNowf(t, "failed to complete expected retrievals", "expected %d, received %d @ %s", len(expectedRetrievals), i, afterStart) + } + } + require.ElementsMatch(t, expectedRetrievals, retrievals) +} + +// deferredBody is simply a Reader that lazily starts a CAR writer on the first +// Read call. +type deferredBody struct { + mrt *MockRoundTripper + remote MockRoundTripRemote + root cid.Cid + + r io.ReadCloser + once sync.Once +} + +func newDeferredBody(mrt *MockRoundTripper, remote MockRoundTripRemote, root cid.Cid) *deferredBody { + return &deferredBody{ + mrt: mrt, + remote: remote, + root: root, + } +} + +var _ io.ReadCloser = (*deferredBody)(nil) + +func (d *deferredBody) makeBody() io.ReadCloser { + carR, carW := io.Pipe() + req := require.New(d.mrt.t) + + sel, err := selector.CompileSelector(d.remote.Selector) + req.NoError(err) + + go func() { + stats := RemoteStats{ + Peer: d.remote.Peer.ID, + Root: d.root, + Blocks: make([]cid.Cid, 0), + } + + defer func() { + d.mrt.statsCh <- stats + req.NoError(carW.Close()) + }() + + if d.remote.Malformed { + carW.Write([]byte("nope, this is not what you're looking for")) + return + } + + // instantiating this writes a CARv1 header and waits for more Put()s + carWriter, err := storage.NewWritable(carW, []cid.Cid{d.root}, car.WriteAsCarV1(true), car.AllowDuplicatePuts(false)) + req.NoError(err) + + // intercept the StorageReadOpener of the LinkSystem so that for each + // read that the traverser performs, we take that block and Put() it + // to the CARv1 writer. + lsys := d.remote.LinkSystem + originalSRO := lsys.StorageReadOpener + lsys.StorageReadOpener = func(lc linking.LinkContext, lnk datamodel.Link) (io.Reader, error) { + r, err := originalSRO(lc, lnk) + if err != nil { + return nil, err + } + byts, err := io.ReadAll(r) + if err != nil { + return nil, err + } + err = carWriter.Put(d.mrt.ctx, lnk.(cidlink.Link).Cid.KeyString(), byts) + req.NoError(err) + stats.Blocks = append(stats.Blocks, lnk.(cidlink.Link).Cid) + stats.ByteCount += uint64(len(byts)) // only the length of the bytes, not the rest of the CAR infrastructure + + // ensure there is blockDuration between each block send + sendAt := d.remote.RespondAt.Add(d.mrt.remoteBlockDuration * time.Duration(len(stats.Blocks))) + if d.mrt.clock.Until(sendAt) > 0 { + select { + case <-d.mrt.ctx.Done(): + return nil, d.mrt.ctx.Err() + case <-d.mrt.clock.After(d.mrt.clock.Until(sendAt)): + time.Sleep(1 * time.Millisecond) // let em goroutines breathe + } + } + return bytes.NewReader(byts), nil + } + + // load and register the root link so it's pushed to the CAR since + // the traverser won't load it (we feed the traverser the rood _node_ + // not the link) + var proto datamodel.NodePrototype = basicnode.Prototype.Any + if d.root.Prefix().Codec == cid.DagProtobuf { + proto = dagpb.Type.PBNode + } + rootNode, err := lsys.Load(linking.LinkContext{Ctx: d.mrt.ctx}, cidlink.Link{Cid: d.root}, proto) + if err != nil { + stats.Err = struct{}{} + } else { + // begin traversal + err := traversal.Progress{ + Cfg: &traversal.Config{ + Ctx: d.mrt.ctx, + LinkSystem: lsys, + LinkTargetNodePrototypeChooser: dagpb.AddSupportToChooser(basicnode.Chooser), + }, + }.WalkAdv(rootNode, sel, func(p traversal.Progress, n datamodel.Node, vr traversal.VisitReason) error { return nil }) + if err != nil { + stats.Err = struct{}{} + } + } + }() + + return carR +} + +func (d *deferredBody) Read(p []byte) (n int, err error) { + d.once.Do(func() { + d.r = d.makeBody() + }) + n, err = d.r.Read(p) + if err == io.EOF { + d.mrt.endsCh <- d.remote.Peer.ID + } + return n, err +} + +func (d *deferredBody) Close() error { + if d.r != nil { + return d.r.Close() + } + return nil +} diff --git a/pkg/retriever/httpretriever_test.go b/pkg/retriever/httpretriever_test.go index 34334a38..f14e2d7a 100644 --- a/pkg/retriever/httpretriever_test.go +++ b/pkg/retriever/httpretriever_test.go @@ -1,13 +1,9 @@ package retriever_test import ( - "bytes" "context" - "fmt" "io" "net/http" - "strings" - "sync" "testing" "time" @@ -20,31 +16,16 @@ import ( "github.com/google/uuid" "github.com/ipfs/go-cid" gstestutil "github.com/ipfs/go-graphsync/testutil" - "github.com/ipld/go-car/v2" - "github.com/ipld/go-car/v2/storage" - dagpb "github.com/ipld/go-codec-dagpb" - "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/datamodel" "github.com/ipld/go-ipld-prime/linking" cidlink "github.com/ipld/go-ipld-prime/linking/cid" - "github.com/ipld/go-ipld-prime/node/basicnode" "github.com/ipld/go-ipld-prime/storage/memstore" - "github.com/ipld/go-ipld-prime/traversal" - "github.com/ipld/go-ipld-prime/traversal/selector" selectorparse "github.com/ipld/go-ipld-prime/traversal/selector/parse" "github.com/ipni/go-libipni/metadata" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" ) -type httpRemote struct { - peer peer.AddrInfo - lsys *linking.LinkSystem - sel ipld.Node - respondAt time.Time - malformed bool -} - func TestHTTPRetriever(t *testing.T) { ctx := context.Background() @@ -81,7 +62,7 @@ func TestHTTPRetriever(t *testing.T) { requests map[cid.Cid]types.RetrievalID requestPath map[cid.Cid]string requestScope map[cid.Cid]types.CarScope - remotes map[cid.Cid][]httpRemote + remotes map[cid.Cid][]testutil.MockRoundTripRemote expectedStats map[cid.Cid]*types.RetrievalStats expectedErrors map[cid.Cid]struct{} expectedCids map[cid.Cid][]cid.Cid // expected in this order @@ -90,13 +71,13 @@ func TestHTTPRetriever(t *testing.T) { { name: "single, one peer, success", requests: map[cid.Cid]types.RetrievalID{cid1: rid1}, - remotes: map[cid.Cid][]httpRemote{ + remotes: map[cid.Cid][]testutil.MockRoundTripRemote{ cid1: { { - peer: cid1Cands[0].MinerPeer, - lsys: makeLsys(tbc1.AllBlocks()), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*40), + Peer: cid1Cands[0].MinerPeer, + LinkSystem: *makeLsys(tbc1.AllBlocks()), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*40), }, }, }, @@ -152,21 +133,21 @@ func TestHTTPRetriever(t *testing.T) { { name: "two parallel, one peer each, success", requests: map[cid.Cid]types.RetrievalID{cid1: rid1, cid2: rid2}, - remotes: map[cid.Cid][]httpRemote{ + remotes: map[cid.Cid][]testutil.MockRoundTripRemote{ cid1: { { - peer: cid1Cands[0].MinerPeer, - lsys: makeLsys(tbc1.AllBlocks()), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*40), + Peer: cid1Cands[0].MinerPeer, + LinkSystem: *makeLsys(tbc1.AllBlocks()), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*40), }, }, cid2: { { - peer: cid2Cands[0].MinerPeer, - lsys: makeLsys(tbc2.AllBlocks()), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*10), + Peer: cid2Cands[0].MinerPeer, + LinkSystem: *makeLsys(tbc2.AllBlocks()), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*10), }, }, }, @@ -256,28 +237,28 @@ func TestHTTPRetriever(t *testing.T) { { name: "single, multiple errors", requests: map[cid.Cid]types.RetrievalID{cid1: rid1}, - remotes: map[cid.Cid][]httpRemote{ + remotes: map[cid.Cid][]testutil.MockRoundTripRemote{ cid1: { { - peer: cid1Cands[0].MinerPeer, - lsys: makeLsys(nil), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*10), - malformed: true, + Peer: cid1Cands[0].MinerPeer, + LinkSystem: *makeLsys(nil), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*10), + Malformed: true, }, { - peer: cid1Cands[1].MinerPeer, - lsys: makeLsys(nil), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*20), - malformed: true, + Peer: cid1Cands[1].MinerPeer, + LinkSystem: *makeLsys(nil), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*20), + Malformed: true, }, { - peer: cid1Cands[2].MinerPeer, - lsys: makeLsys(nil), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*30), - malformed: true, + Peer: cid1Cands[2].MinerPeer, + LinkSystem: *makeLsys(nil), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*30), + Malformed: true, }, }, }, @@ -355,27 +336,27 @@ func TestHTTPRetriever(t *testing.T) { { name: "single, multiple errors, one success", requests: map[cid.Cid]types.RetrievalID{cid1: rid1}, - remotes: map[cid.Cid][]httpRemote{ + remotes: map[cid.Cid][]testutil.MockRoundTripRemote{ cid1: { { - peer: cid1Cands[0].MinerPeer, - lsys: makeLsys(nil), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*10), - malformed: true, + Peer: cid1Cands[0].MinerPeer, + LinkSystem: *makeLsys(nil), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*10), + Malformed: true, }, { - peer: cid1Cands[1].MinerPeer, - lsys: makeLsys(nil), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*20), - malformed: true, + Peer: cid1Cands[1].MinerPeer, + LinkSystem: *makeLsys(nil), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*20), + Malformed: true, }, { - peer: cid1Cands[2].MinerPeer, - lsys: makeLsys(tbc1.AllBlocks()), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*30), + Peer: cid1Cands[2].MinerPeer, + LinkSystem: *makeLsys(tbc1.AllBlocks()), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*30), }, }, }, @@ -472,13 +453,13 @@ func TestHTTPRetriever(t *testing.T) { { name: "single, one peer, partial served", requests: map[cid.Cid]types.RetrievalID{cid1: rid1}, - remotes: map[cid.Cid][]httpRemote{ + remotes: map[cid.Cid][]testutil.MockRoundTripRemote{ cid1: { { - peer: cid1Cands[0].MinerPeer, - lsys: makeLsys(tbc1.AllBlocks()[0:50]), - sel: allSelector, - respondAt: startTime.Add(initialPause + time.Millisecond*40), + Peer: cid1Cands[0].MinerPeer, + LinkSystem: *makeLsys(tbc1.AllBlocks()[0:50]), + Selector: allSelector, + RespondAt: startTime.Add(initialPause + time.Millisecond*40), }, }, }, @@ -537,19 +518,7 @@ func TestHTTPRetriever(t *testing.T) { clock.Set(startTime) awaitReceivedCandidates := make(chan struct{}, 1) - getRemote := func(cid cid.Cid, maddr string) httpRemote { - remotes, ok := testCase.remotes[cid] - req.True(ok) - for _, remote := range remotes { - if remote.peer.Addrs[0].String() == maddr { - return remote - } - } - t.Fatal("remote not found") - return httpRemote{} - } - - roundTripper := NewCannedBytesRoundTripper(t, ctx, clock, remoteBlockDuration, testCase.requestPath, testCase.requestScope, getRemote) + roundTripper := testutil.NewMockRoundTripper(t, ctx, clock, remoteBlockDuration, testCase.requestPath, testCase.requestScope, testCase.remotes) client := &http.Client{Transport: roundTripper} // customCompare lets us order candidates when they queue, since we currently // have no other way to deterministically order them for testing. @@ -627,10 +596,10 @@ func TestHTTPRetriever(t *testing.T) { } } -func toCandidates(root cid.Cid, remotes []httpRemote) []types.RetrievalCandidate { +func toCandidates(root cid.Cid, remotes []testutil.MockRoundTripRemote) []types.RetrievalCandidate { candidates := make([]types.RetrievalCandidate, len(remotes)) for i, r := range remotes { - candidates[i] = toCandidate(root, r.peer) + candidates[i] = toCandidate(root, r.Peer) } return candidates } @@ -660,289 +629,3 @@ func (ba *blockAccounter) StorageWriteOpener(lctx linking.LinkContext) (io.Write return wc(l) }, err } - -type cannedBytesRoundTripper struct { - StartsCh chan peer.ID - StatsCh chan testutil.RemoteStats - EndsCh chan peer.ID - - t *testing.T - ctx context.Context - clock *clock.Mock - remoteBlockDuration time.Duration - expectedPath map[cid.Cid]string - expectedScope map[cid.Cid]types.CarScope - getRemote func(cid cid.Cid, maddr string) httpRemote -} - -var _ http.RoundTripper = (*cannedBytesRoundTripper)(nil) - -func NewCannedBytesRoundTripper( - t *testing.T, - ctx context.Context, - clock *clock.Mock, - remoteBlockDuration time.Duration, - expectedPath map[cid.Cid]string, - expectedScope map[cid.Cid]types.CarScope, - getRemote func(cid cid.Cid, maddr string) httpRemote, -) *cannedBytesRoundTripper { - return &cannedBytesRoundTripper{ - make(chan peer.ID, 32), - make(chan testutil.RemoteStats, 32), - make(chan peer.ID, 32), - t, - ctx, - clock, - remoteBlockDuration, - expectedPath, - expectedScope, - getRemote, - } -} - -func (c *cannedBytesRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - us := strings.Split(req.URL.Path, "/") - require.True(c.t, len(us) > 2) - require.Equal(c.t, us[1], "ipfs") - root, err := cid.Parse(us[2]) - require.NoError(c.t, err) - path := strings.Join(us[3:], "/") - expectedPath, ok := c.expectedPath[root] - if !ok { - require.Equal(c.t, path, "") - } else { - require.Equal(c.t, path, expectedPath) - } - expectedScope := types.CarScopeAll - if scope, ok := c.expectedScope[root]; ok { - expectedScope = scope - } - require.Equal(c.t, req.URL.RawQuery, fmt.Sprintf("car-scope=%s", expectedScope)) - ip := req.URL.Hostname() - port := req.URL.Port() - maddr := fmt.Sprintf("/ip4/%s/tcp/%s/http", ip, port) - remote := c.getRemote(root, maddr) - c.StartsCh <- remote.peer.ID - - sleepFor := c.clock.Until(remote.respondAt) - if sleepFor > 0 { - select { - case <-c.ctx.Done(): - return nil, c.ctx.Err() - case <-c.clock.After(sleepFor): - } - } - - makeBody := func(root cid.Cid, maddr string) io.ReadCloser { - carR, carW := io.Pipe() - statsCh := traverseCar( - c.t, - c.ctx, - remote.peer.ID, - c.clock, - remote.respondAt, - c.remoteBlockDuration, - carW, - remote.malformed, - remote.lsys, - root, - remote.sel, - ) - go func() { - select { - case <-c.ctx.Done(): - return - case stats, ok := <-statsCh: - if !ok { - return - } - c.StatsCh <- stats - } - }() - return carR - } - return &http.Response{ - StatusCode: http.StatusOK, - Body: &deferredReader{root: root, maddr: maddr, makeBody: makeBody, end: func() { c.EndsCh <- remote.peer.ID }}, - }, nil -} - -func (c *cannedBytesRoundTripper) VerifyConnectionsReceived(ctx context.Context, t *testing.T, afterStart time.Duration, expectedConnections []peer.ID) { - if len(expectedConnections) > 0 { - require.FailNowf(t, "unexpected ConnectionsReceived", "@ %s", afterStart) - } -} - -func (c *cannedBytesRoundTripper) VerifyRetrievalsReceived(ctx context.Context, t *testing.T, afterStart time.Duration, expectedRetrievals []peer.ID) { - retrievals := make([]peer.ID, 0, len(expectedRetrievals)) - for i := 0; i < len(expectedRetrievals); i++ { - select { - case retrieval := <-c.StartsCh: - retrievals = append(retrievals, retrieval) - case <-ctx.Done(): - require.FailNowf(t, "failed to receive expected retrievals", "expected %d, received %d @ %s", len(expectedRetrievals), i, afterStart) - } - } - require.ElementsMatch(t, expectedRetrievals, retrievals) -} - -func (c *cannedBytesRoundTripper) VerifyRetrievalsServed(ctx context.Context, t *testing.T, afterStart time.Duration, expectedServed []testutil.RemoteStats) { - remoteStats := make([]testutil.RemoteStats, 0, len(expectedServed)) - for i := 0; i < len(expectedServed); i++ { - select { - case stats := <-c.StatsCh: - remoteStats = append(remoteStats, stats) - case <-ctx.Done(): - require.FailNowf(t, "failed to receive expected served", "expected %d, received %d @ %s", len(expectedServed), i, afterStart) - } - } - require.ElementsMatch(t, expectedServed, remoteStats) -} - -func (c *cannedBytesRoundTripper) VerifyRetrievalsCompleted(ctx context.Context, t *testing.T, afterStart time.Duration, expectedRetrievals []peer.ID) { - retrievals := make([]peer.ID, 0, len(expectedRetrievals)) - for i := 0; i < len(expectedRetrievals); i++ { - select { - case retrieval := <-c.EndsCh: - retrievals = append(retrievals, retrieval) - case <-ctx.Done(): - require.FailNowf(t, "failed to complete expected retrievals", "expected %d, received %d @ %s", len(expectedRetrievals), i, afterStart) - } - } - require.ElementsMatch(t, expectedRetrievals, retrievals) -} - -// deferredReader is simply a Reader that lazily calls makeBody on the first Read -// so we don't begin CAR generation if the HTTP response body never gets read by -// the client. -type deferredReader struct { - root cid.Cid - maddr string - makeBody func(cid.Cid, string) io.ReadCloser - end func() - - r io.ReadCloser - once sync.Once -} - -var _ io.ReadCloser = (*deferredReader)(nil) - -func (d *deferredReader) Read(p []byte) (n int, err error) { - d.once.Do(func() { - d.r = d.makeBody(d.root, d.maddr) - }) - n, err = d.r.Read(p) - if err == io.EOF { - d.end() - } - return n, err -} - -func (d *deferredReader) Close() error { - if d.r != nil { - return d.r.Close() - } - return nil -} - -// given a writer (carW), a linkSystem, a root CID and a selector, traverse the graph -// and write the blocks in CARv1 format to the writer. Return a channel that will -// receive basic stats on what was written _after_ the write is finished. -func traverseCar( - t *testing.T, - ctx context.Context, - id peer.ID, - clock *clock.Mock, - startTime time.Time, - blockDuration time.Duration, - carW io.WriteCloser, - malformed bool, - lsys *linking.LinkSystem, - root cid.Cid, - selNode ipld.Node, -) chan testutil.RemoteStats { - - req := require.New(t) - - sel, err := selector.CompileSelector(selNode) - req.NoError(err) - - statsCh := make(chan testutil.RemoteStats, 1) - go func() { - stats := testutil.RemoteStats{ - Peer: id, - Root: root, - Blocks: make([]cid.Cid, 0), - } - - defer func() { - statsCh <- stats - req.NoError(carW.Close()) - }() - - if malformed { - carW.Write([]byte("nope, this is not what you're looking for")) - return - } - - // instantiating this writes a CARv1 header and waits for more Put()s - carWriter, err := storage.NewWritable(carW, []cid.Cid{root}, car.WriteAsCarV1(true), car.AllowDuplicatePuts(false)) - req.NoError(err) - - // intercept the StorageReadOpener of the LinkSystem so that for each - // read that the traverser performs, we take that block and Put() it - // to the CARv1 writer. - originalSRO := lsys.StorageReadOpener - lsys.StorageReadOpener = func(lc linking.LinkContext, lnk datamodel.Link) (io.Reader, error) { - r, err := originalSRO(lc, lnk) - if err != nil { - return nil, err - } - byts, err := io.ReadAll(r) - if err != nil { - return nil, err - } - err = carWriter.Put(ctx, lnk.(cidlink.Link).Cid.KeyString(), byts) - req.NoError(err) - stats.Blocks = append(stats.Blocks, lnk.(cidlink.Link).Cid) - stats.ByteCount += uint64(len(byts)) // only the length of the bytes, not the rest of the CAR infrastructure - - // ensure there is blockDuration between each block send - sendAt := startTime.Add(blockDuration * time.Duration(len(stats.Blocks))) - if clock.Until(sendAt) > 0 { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-clock.After(clock.Until(sendAt)): - time.Sleep(1 * time.Millisecond) // let em goroutines breathe - } - } - return bytes.NewReader(byts), nil - } - - // load and register the root link so it's pushed to the CAR since - // the traverser won't load it (we feed the traverser the rood _node_ - // not the link) - var proto datamodel.NodePrototype = basicnode.Prototype.Any - if root.Prefix().Codec == cid.DagProtobuf { - proto = dagpb.Type.PBNode - } - rootNode, err := lsys.Load(linking.LinkContext{}, cidlink.Link{Cid: root}, proto) - if err != nil { - stats.Err = struct{}{} - } else { - // begin traversal - err := traversal.Progress{ - Cfg: &traversal.Config{ - Ctx: ctx, - LinkSystem: *lsys, - LinkTargetNodePrototypeChooser: dagpb.AddSupportToChooser(basicnode.Chooser), - }, - }.WalkAdv(rootNode, sel, func(p traversal.Progress, n datamodel.Node, vr traversal.VisitReason) error { return nil }) - if err != nil { - stats.Err = struct{}{} - } - } - }() - return statsCh -}