From eddbf615a5a533b0c1c044bdfd898d2dc632d108 Mon Sep 17 00:00:00 2001 From: Hannah Howard Date: Fri, 14 Jan 2022 14:40:10 -0800 Subject: [PATCH] Removing CID Lists part one: unblocking technical obstacles (#292) * feat(channels): remove cid lists for sender remove a synchronous disk based lookup of CID lists * feat(channels): don't block on checking cid lists * fix(channels): fix potential race in block index cache --- channelmonitor/channelmonitor_test.go | 8 ++ channels/block_index_cache.go | 63 ++++++++++++ channels/channel_state.go | 20 ++++ channels/channels.go | 70 +++++++------ channels/channels_fsm.go | 10 +- channels/channels_test.go | 27 +++-- channels/internal/internalchannel.go | 7 +- channels/internal/internalchannel_cbor_gen.go | 98 ++++++++++++++++++- impl/events.go | 12 +-- impl/initiating_test.go | 4 +- impl/responding_test.go | 16 +-- impl/restart_integration_test.go | 11 --- transport.go | 6 +- transport/graphsync/graphsync.go | 6 +- transport/graphsync/graphsync_test.go | 14 ++- types.go | 8 ++ 16 files changed, 301 insertions(+), 79 deletions(-) create mode 100644 channels/block_index_cache.go diff --git a/channelmonitor/channelmonitor_test.go b/channelmonitor/channelmonitor_test.go index ed833963..b710530b 100644 --- a/channelmonitor/channelmonitor_test.go +++ b/channelmonitor/channelmonitor_test.go @@ -625,3 +625,11 @@ func (m *mockChannelState) ReceivedCidsTotal() int64 { func (m *mockChannelState) MissingCids() []cid.Cid { panic("implement me") } + +func (m *mockChannelState) QueuedCidsTotal() int64 { + panic("implement me") +} + +func (m *mockChannelState) SentCidsTotal() int64 { + panic("implement me") +} diff --git a/channels/block_index_cache.go b/channels/block_index_cache.go new file mode 100644 index 00000000..490f77fd --- /dev/null +++ b/channels/block_index_cache.go @@ -0,0 +1,63 @@ +package channels + +import ( + "sync" + "sync/atomic" + + datatransfer "github.com/filecoin-project/go-data-transfer" +) + +type readOriginalFn func(datatransfer.ChannelID) (int64, error) + +type blockIndexKey struct { + evt datatransfer.EventCode + chid datatransfer.ChannelID +} +type blockIndexCache struct { + lk sync.RWMutex + values map[blockIndexKey]*int64 +} + +func newBlockIndexCache() *blockIndexCache { + return &blockIndexCache{ + values: make(map[blockIndexKey]*int64), + } +} + +func (bic *blockIndexCache) getValue(evt datatransfer.EventCode, chid datatransfer.ChannelID, readFromOriginal readOriginalFn) (*int64, error) { + idxKey := blockIndexKey{evt, chid} + bic.lk.RLock() + value := bic.values[idxKey] + bic.lk.RUnlock() + if value != nil { + return value, nil + } + bic.lk.Lock() + defer bic.lk.Unlock() + value = bic.values[idxKey] + if value != nil { + return value, nil + } + newValue, err := readFromOriginal(chid) + if err != nil { + return nil, err + } + bic.values[idxKey] = &newValue + return &newValue, nil +} + +func (bic *blockIndexCache) updateIfGreater(evt datatransfer.EventCode, chid datatransfer.ChannelID, newIndex int64, readFromOriginal readOriginalFn) (bool, error) { + value, err := bic.getValue(evt, chid, readFromOriginal) + if err != nil { + return false, err + } + for { + currentIndex := atomic.LoadInt64(value) + if newIndex <= currentIndex { + return false, nil + } + if atomic.CompareAndSwapInt64(value, currentIndex, newIndex) { + return true, nil + } + } +} diff --git a/channels/channel_state.go b/channels/channel_state.go index 6a26388c..76141249 100644 --- a/channels/channel_state.go +++ b/channels/channel_state.go @@ -43,6 +43,12 @@ type channelState struct { // number of blocks that have been received, including blocks that are // present in more than one place in the DAG receivedBlocksTotal int64 + // Number of blocks that have been queued, including blocks that are + // present in more than one place in the DAG + queuedBlocksTotal int64 + // Number of blocks that have been sent, including blocks that are + // present in more than one place in the DAG + sentBlocksTotal int64 // more informative status on a channel message string // additional vouchers @@ -128,6 +134,18 @@ func (c channelState) ReceivedCidsTotal() int64 { return c.receivedBlocksTotal } +// QueuedCidsTotal returns the number of (non-unique) cids queued so far +// on the channel - note that a block can exist in more than one place in the DAG +func (c channelState) QueuedCidsTotal() int64 { + return c.queuedBlocksTotal +} + +// SentCidsTotal returns the number of (non-unique) cids sent so far +// on the channel - note that a block can exist in more than one place in the DAG +func (c channelState) SentCidsTotal() int64 { + return c.sentBlocksTotal +} + // Sender returns the peer id for the node that is sending data func (c channelState) Sender() peer.ID { return c.sender } @@ -230,6 +248,8 @@ func fromInternalChannelState(c internal.ChannelState, voucherDecoder DecoderByT sent: c.Sent, received: c.Received, receivedBlocksTotal: c.ReceivedBlocksTotal, + queuedBlocksTotal: c.QueuedBlocksTotal, + sentBlocksTotal: c.SentBlocksTotal, message: c.Message, vouchers: c.Vouchers, voucherResults: c.VoucherResults, diff --git a/channels/channels.go b/channels/channels.go index b5c1b048..90c73a6c 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -56,6 +56,7 @@ type Channels struct { notifier Notifier voucherDecoder DecoderByTypeFunc voucherResultDecoder DecoderByTypeFunc + blockIndexCache *blockIndexCache stateMachines fsm.Group migrateStateMachines func(context.Context) error seenCIDs *cidsets.CIDSetManager @@ -85,6 +86,7 @@ func New(ds datastore.Batching, voucherDecoder: voucherDecoder, voucherResultDecoder: voucherResultDecoder, } + c.blockIndexCache = newBlockIndexCache() channelMigrations, err := migrations.GetChannelStateMigrations(selfPeer, cidLists) if err != nil { return nil, err @@ -234,38 +236,48 @@ func (c *Channels) CompleteCleanupOnRestart(chid datatransfer.ChannelID) error { return c.send(chid, datatransfer.CompleteCleanupOnRestart) } -// Returns true if this is the first time the block has been sent -func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) { - return c.fireProgressEvent(chid, datatransfer.DataSent, datatransfer.DataSentProgress, k, delta) -} - -// Returns true if this is the first time the block has been queued -func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64) (bool, error) { - return c.fireProgressEvent(chid, datatransfer.DataQueued, datatransfer.DataQueuedProgress, k, delta) +func (c *Channels) getQueuedIndex(chid datatransfer.ChannelID) (int64, error) { + chst, err := c.GetByID(context.TODO(), chid) + if err != nil { + return 0, err + } + return chst.QueuedCidsTotal(), nil } -// Returns true if this is the first time the block has been received -func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64) (bool, error) { - if err := c.checkChannelExists(chid, datatransfer.DataReceived); err != nil { - return false, err +func (c *Channels) getReceivedIndex(chid datatransfer.ChannelID) (int64, error) { + chst, err := c.GetByID(context.TODO(), chid) + if err != nil { + return 0, err } + return chst.ReceivedCidsTotal(), nil +} - // Check if the block has already been seen - sid := seenCidsSetID(chid, datatransfer.DataReceived) - seen, err := c.seenCIDs.InsertSetCID(sid, k) +func (c *Channels) getSentIndex(chid datatransfer.ChannelID) (int64, error) { + chst, err := c.GetByID(context.TODO(), chid) if err != nil { - return false, err + return 0, err } + return chst.SentCidsTotal(), nil +} - // If the block has not been seen before, fire the progress event - if !seen { - if err := c.stateMachines.Send(chid, datatransfer.DataReceivedProgress, delta); err != nil { - return false, err - } - } +func (c *Channels) DataSent(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64, unique bool) (bool, error) { + return c.fireProgressEvent(chid, datatransfer.DataSent, datatransfer.DataSentProgress, k, delta, index, unique, c.getSentIndex) +} - // Fire the regular event - return !seen, c.stateMachines.Send(chid, datatransfer.DataReceived, index) +func (c *Channels) DataQueued(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64, unique bool) (bool, error) { + return c.fireProgressEvent(chid, datatransfer.DataQueued, datatransfer.DataQueuedProgress, k, delta, index, unique, c.getQueuedIndex) +} + +// Returns true if this is the first time the block has been received +func (c *Channels) DataReceived(chid datatransfer.ChannelID, k cid.Cid, delta uint64, index int64, unique bool) (bool, error) { + new, err := c.fireProgressEvent(chid, datatransfer.DataReceived, datatransfer.DataReceivedProgress, k, delta, index, unique, c.getReceivedIndex) + // TODO: remove when ReceivedCids and legacy protocol is removed + // write the seen received cids, but write async in order to avoid blocking processing + if err == nil { + sid := seenCidsSetID(chid, datatransfer.DataReceived) + _, _ = c.seenCIDs.InsertSetCID(sid, k) + } + return new, err } // PauseInitiator pauses the initator of this channel @@ -409,27 +421,25 @@ func (c *Channels) removeSeenCIDCaches(chid datatransfer.ChannelID) error { // fire both DataSent AND DataSentProgress. // If a block is resent, the method will fire DataSent but not DataSentProgress. // Returns true if the block is new (both the event and a progress event were fired). -func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, k cid.Cid, delta uint64) (bool, error) { +func (c *Channels) fireProgressEvent(chid datatransfer.ChannelID, evt datatransfer.EventCode, progressEvt datatransfer.EventCode, k cid.Cid, delta uint64, index int64, unique bool, readFromOriginal readOriginalFn) (bool, error) { if err := c.checkChannelExists(chid, evt); err != nil { return false, err } - // Check if the block has already been seen - sid := seenCidsSetID(chid, evt) - seen, err := c.seenCIDs.InsertSetCID(sid, k) + isNewIndex, err := c.blockIndexCache.updateIfGreater(evt, chid, index, readFromOriginal) if err != nil { return false, err } // If the block has not been seen before, fire the progress event - if !seen { + if unique && isNewIndex { if err := c.stateMachines.Send(chid, progressEvt, delta); err != nil { return false, err } } // Fire the regular event - return !seen, c.stateMachines.Send(chid, evt) + return unique && isNewIndex, c.stateMachines.Send(chid, evt, index) } func (c *Channels) send(chid datatransfer.ChannelID, code datatransfer.EventCode, args ...interface{}) error { diff --git a/channels/channels_fsm.go b/channels/channels_fsm.go index f1121b61..712d6dd6 100644 --- a/channels/channels_fsm.go +++ b/channels/channels_fsm.go @@ -79,7 +79,10 @@ var ChannelEvents = fsm.Events{ fsm.Event(datatransfer.DataSent). FromMany(transferringStates...).ToNoChange(). From(datatransfer.TransferFinished).ToNoChange(). - Action(func(chst *internal.ChannelState) error { + Action(func(chst *internal.ChannelState, sentBlocksTotal int64) error { + if sentBlocksTotal > chst.SentBlocksTotal { + chst.SentBlocksTotal = sentBlocksTotal + } chst.AddLog("") return nil }), @@ -94,7 +97,10 @@ var ChannelEvents = fsm.Events{ fsm.Event(datatransfer.DataQueued). FromMany(transferringStates...).ToNoChange(). From(datatransfer.TransferFinished).ToNoChange(). - Action(func(chst *internal.ChannelState) error { + Action(func(chst *internal.ChannelState, queuedBlocksTotal int64) error { + if queuedBlocksTotal > chst.QueuedBlocksTotal { + chst.QueuedBlocksTotal = queuedBlocksTotal + } chst.AddLog("") return nil }), diff --git a/channels/channels_test.go b/channels/channels_test.go index 6a62e3d6..2dd365da 100644 --- a/channels/channels_test.go +++ b/channels/channels_test.go @@ -158,13 +158,13 @@ func TestChannels(t *testing.T) { require.Equal(t, datatransfer.TransferFinished, state.Status()) // send a data-sent event and ensure it's a no-op - _, err = channelList.DataSent(chid, cids[1], 1) + _, err = channelList.DataSent(chid, cids[1], 1, 1, true) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.DataSent) require.Equal(t, datatransfer.TransferFinished, state.Status()) // send a data-queued event and ensure it's a no-op. - _, err = channelList.DataQueued(chid, cids[1], 1) + _, err = channelList.DataQueued(chid, cids[1], 1, 1, true) require.NoError(t, err) state = checkEvent(ctx, t, received, datatransfer.DataQueued) require.Equal(t, datatransfer.TransferFinished, state.Status()) @@ -188,7 +188,7 @@ func TestChannels(t *testing.T) { require.Equal(t, uint64(0), state.Sent()) require.Empty(t, state.ReceivedCids()) - isNew, err := channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50, 1) + isNew, err := channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50, 1, true) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress) require.True(t, isNew) @@ -197,7 +197,7 @@ func TestChannels(t *testing.T) { require.Equal(t, uint64(0), state.Sent()) require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids()) - isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100) + isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100, 1, true) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataSentProgress) require.True(t, isNew) @@ -206,16 +206,24 @@ func TestChannels(t *testing.T) { require.Equal(t, uint64(100), state.Sent()) require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids()) + // send block again has no effect + isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 100, 1, true) + require.NoError(t, err) + require.False(t, isNew) + state = checkEvent(ctx, t, received, datatransfer.DataSent) + require.Equal(t, uint64(50), state.Received()) + require.Equal(t, uint64(100), state.Sent()) + // errors if channel does not exist - isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200, 2) + isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200, 2, true) require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) require.False(t, isNew) - isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200) + isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[1], Responder: peers[0], ID: tid1}, cids[1], 200, 2, true) require.True(t, xerrors.As(err, new(*channels.ErrNotFound))) require.Equal(t, []cid.Cid{cids[0]}, state.ReceivedCids()) require.False(t, isNew) - isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50, 2) + isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 50, 2, true) require.NoError(t, err) _ = checkEvent(ctx, t, received, datatransfer.DataReceivedProgress) require.True(t, isNew) @@ -224,7 +232,7 @@ func TestChannels(t *testing.T) { require.Equal(t, uint64(100), state.Sent()) require.ElementsMatch(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) - isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25) + isNew, err = channelList.DataSent(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[1], 25, 2, false) require.NoError(t, err) require.False(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataSent) @@ -232,12 +240,13 @@ func TestChannels(t *testing.T) { require.Equal(t, uint64(100), state.Sent()) require.ElementsMatch(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) - isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50, 3) + isNew, err = channelList.DataReceived(datatransfer.ChannelID{Initiator: peers[0], Responder: peers[1], ID: tid1}, cids[0], 50, 3, false) require.NoError(t, err) require.False(t, isNew) state = checkEvent(ctx, t, received, datatransfer.DataReceived) require.Equal(t, uint64(100), state.Received()) require.Equal(t, uint64(100), state.Sent()) + require.ElementsMatch(t, []cid.Cid{cids[0], cids[1]}, state.ReceivedCids()) }) diff --git a/channels/internal/internalchannel.go b/channels/internal/internalchannel.go index 28c8ddf7..f233d673 100644 --- a/channels/internal/internalchannel.go +++ b/channels/internal/internalchannel.go @@ -63,7 +63,12 @@ type ChannelState struct { // Number of blocks that have been received, including blocks that are // present in more than one place in the DAG ReceivedBlocksTotal int64 - + // Number of blocks that have been queued, including blocks that are + // present in more than one place in the DAG + QueuedBlocksTotal int64 + // Number of blocks that have been sent, including blocks that are + // present in more than one place in the DAG + SentBlocksTotal int64 // Stages traces the execution fo a data transfer. // // EXPERIMENTAL; subject to change. diff --git a/channels/internal/internalchannel_cbor_gen.go b/channels/internal/internalchannel_cbor_gen.go index 7f5f406b..84536672 100644 --- a/channels/internal/internalchannel_cbor_gen.go +++ b/channels/internal/internalchannel_cbor_gen.go @@ -23,7 +23,7 @@ func (t *ChannelState) MarshalCBOR(w io.Writer) error { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{179}); err != nil { + if _, err := w.Write([]byte{181}); err != nil { return err } @@ -367,6 +367,50 @@ func (t *ChannelState) MarshalCBOR(w io.Writer) error { } } + // t.QueuedBlocksTotal (int64) (int64) + if len("QueuedBlocksTotal") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"QueuedBlocksTotal\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("QueuedBlocksTotal"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("QueuedBlocksTotal")); err != nil { + return err + } + + if t.QueuedBlocksTotal >= 0 { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.QueuedBlocksTotal)); err != nil { + return err + } + } else { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.QueuedBlocksTotal-1)); err != nil { + return err + } + } + + // t.SentBlocksTotal (int64) (int64) + if len("SentBlocksTotal") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"SentBlocksTotal\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("SentBlocksTotal"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("SentBlocksTotal")); err != nil { + return err + } + + if t.SentBlocksTotal >= 0 { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.SentBlocksTotal)); err != nil { + return err + } + } else { + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajNegativeInt, uint64(-t.SentBlocksTotal-1)); err != nil { + return err + } + } + // t.Stages (datatransfer.ChannelStages) (struct) if len("Stages") > cbg.MaxLength { return xerrors.Errorf("Value in field \"Stages\" was too long") @@ -709,6 +753,58 @@ func (t *ChannelState) UnmarshalCBOR(r io.Reader) error { t.ReceivedBlocksTotal = int64(extraI) } + // t.QueuedBlocksTotal (int64) (int64) + case "QueuedBlocksTotal": + { + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative oveflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.QueuedBlocksTotal = int64(extraI) + } + // t.SentBlocksTotal (int64) (int64) + case "SentBlocksTotal": + { + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative oveflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.SentBlocksTotal = int64(extraI) + } // t.Stages (datatransfer.ChannelStages) (struct) case "Stages": diff --git a/impl/events.go b/impl/events.go index ced32ad6..442b8e77 100644 --- a/impl/events.go +++ b/impl/events.go @@ -41,7 +41,7 @@ func (m *manager) OnChannelOpened(chid datatransfer.ChannelID) error { // It fires an event on the channel, updating the sum of received data and // calls revalidators so they can pause / resume the channel or send a // message over the transport. -func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64) error { +func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) ctx, span := otel.Tracer("data-transfer").Start(ctx, "dataReceived", trace.WithAttributes( attribute.String("channelID", chid.String()), @@ -51,7 +51,7 @@ func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, si )) defer span.End() - isNew, err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size, index) + isNew, err := m.channels.DataReceived(chid, link.(cidlink.Link).Cid, size, index, unique) if err != nil { return err } @@ -97,7 +97,7 @@ func (m *manager) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, si // up some data to be sent to the requester. // It fires an event on the channel, updating the sum of queued data and calls // revalidators so they can pause / resume or send a message over the transport. -func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64) (datatransfer.Message, error) { +func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) (datatransfer.Message, error) { // The transport layer reports that some data has been queued up to be sent // to the requester, so fire a DataQueued event on the channels state // machine. @@ -110,7 +110,7 @@ func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size )) defer span.End() - isNew, err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size) + isNew, err := m.channels.DataQueued(chid, link.(cidlink.Link).Cid, size, index, unique) if err != nil { return nil, err } @@ -147,7 +147,7 @@ func (m *manager) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size return nil, nil } -func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64) error { +func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { ctx, _ := m.spansIndex.SpanForChannel(context.TODO(), chid) ctx, span := otel.Tracer("data-transfer").Start(ctx, "dataSent", trace.WithAttributes( @@ -157,7 +157,7 @@ func (m *manager) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size u )) defer span.End() - _, err := m.channels.DataSent(chid, link.(cidlink.Link).Cid, size) + _, err := m.channels.DataSent(chid, link.(cidlink.Link).Cid, size, index, unique) return err } diff --git a/impl/initiating_test.go b/impl/initiating_test.go index 2bb4aef6..c5c25827 100644 --- a/impl/initiating_test.go +++ b/impl/initiating_test.go @@ -377,8 +377,8 @@ func TestDataTransferRestartInitiating(t *testing.T) { testCids := testutil.GenerateCids(2) ev, ok := h.dt.(datatransfer.EventsHandler) require.True(t, ok) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[0]}, 12345, 1)) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[1]}, 12345, 2)) + require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[0]}, 12345, 1, true)) + require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) // restart that pull channel err = h.dt.RestartDataTransferChannel(ctx, channelID) diff --git a/impl/responding_test.go b/impl/responding_test.go index 527c2c4b..90ef4b51 100644 --- a/impl/responding_test.go +++ b/impl/responding_test.go @@ -287,7 +287,7 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1) + err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1, true) require.EqualError(t, err, datatransfer.ErrPause.Error()) require.Len(t, h.network.SentMessages, 1) response, ok := h.network.SentMessages[0].Message.(datatransfer.Response) @@ -329,7 +329,7 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1) + err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1, true) require.Error(t, err) require.Len(t, h.network.SentMessages, 1) response, ok := h.network.SentMessages[0].Message.(datatransfer.Response) @@ -367,7 +367,7 @@ func TestDataTransferResponding(t *testing.T) { }, verify: func(t *testing.T, h *receiverHarness) { h.network.Delegate.ReceiveRequest(h.ctx, h.peers[1], h.pushRequest) - err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1) + err := h.transport.EventHandler.OnDataReceived(channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, 12345, 1, true) require.EqualError(t, err, datatransfer.ErrPause.Error()) require.Len(t, h.network.SentMessages, 1) response, ok := h.network.SentMessages[0].Message.(datatransfer.Response) @@ -419,7 +419,7 @@ func TestDataTransferResponding(t *testing.T) { msg, err := h.transport.EventHandler.OnDataQueued( channelID(h.id, h.peers), cidlink.Link{Cid: testutil.GenerateCids(1)[0]}, - 12345) + 12345, 1, true) require.EqualError(t, err, datatransfer.ErrPause.Error()) response, ok := msg.(datatransfer.Response) require.True(t, ok) @@ -645,8 +645,8 @@ func TestDataTransferRestartResponding(t *testing.T) { testCids := testutil.GenerateCids(2) ev, ok := h.dt.(datatransfer.EventsHandler) require.True(t, ok) - require.NoError(t, ev.OnDataReceived(chid, cidlink.Link{Cid: testCids[0]}, 12345, 1)) - require.NoError(t, ev.OnDataReceived(chid, cidlink.Link{Cid: testCids[1]}, 12345, 2)) + require.NoError(t, ev.OnDataReceived(chid, cidlink.Link{Cid: testCids[0]}, 12345, 1, true)) + require.NoError(t, ev.OnDataReceived(chid, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) // receive restart push request req, err := message.NewRequest(h.pushRequest.TransferID(), true, false, h.voucher.Type(), h.voucher, @@ -857,8 +857,8 @@ func TestDataTransferRestartResponding(t *testing.T) { testCids := testutil.GenerateCids(2) ev, ok := h.dt.(datatransfer.EventsHandler) require.True(t, ok) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[0]}, 12345, 1)) - require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[1]}, 12345, 2)) + require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[0]}, 12345, 1, true)) + require.NoError(t, ev.OnDataReceived(channelID, cidlink.Link{Cid: testCids[1]}, 12345, 2, true)) // send a request to restart the same pull channel restartReq := message.RestartExistingChannelRequest(channelID) diff --git a/impl/restart_integration_test.go b/impl/restart_integration_test.go index 02556c6f..785af679 100644 --- a/impl/restart_integration_test.go +++ b/impl/restart_integration_test.go @@ -143,7 +143,6 @@ func TestRestartPush(t *testing.T) { queued := make(chan uint64, totalIncrements*2) sent := make(chan uint64, totalIncrements*2) received := make(chan uint64, totalIncrements*2) - var receivedCids []cid.Cid receivedTillNow := atomic.NewInt32(0) // counters we will check at the end for correctness @@ -182,15 +181,6 @@ func TestRestartPush(t *testing.T) { finishedPeersLk.Lock() { finishedPeers = append(finishedPeers, channelState.SelfPeer()) - - // When the receiving peer completes, record received CIDs - // before they get cleaned up - if channelState.SelfPeer() == rh.peer2 { - chs, err := rh.dt2.InProgressChannels(rh.testCtx) - require.NoError(t, err) - require.Len(t, chs, 1) - receivedCids = chs[chid].ReceivedCids() - } } finishedPeersLk.Unlock() finished <- channelState.SelfPeer() @@ -263,7 +253,6 @@ func TestRestartPush(t *testing.T) { require.NoError(t, err) // verify all cids are present on the receiver - require.Equal(t, totalIncrements, len(receivedCids)) testutil.VerifyHasFile(rh.testCtx, t, rh.destDagService, rh.root, rh.origBytes) rh.sv.VerifyExpectations(t) diff --git a/transport.go b/transport.go index dfd9d697..8719a9f9 100644 --- a/transport.go +++ b/transport.go @@ -23,7 +23,7 @@ type EventsHandler interface { // - nil = proceed with sending data // - error = cancel this request // - err == ErrPause - pause this request - OnDataReceived(chid ChannelID, link ipld.Link, size uint64, index int64) error + OnDataReceived(chid ChannelID, link ipld.Link, size uint64, index int64, unique bool) error // OnDataQueued is called when data is queued for sending for the given channel ID // return values are: @@ -32,10 +32,10 @@ type EventsHandler interface { // - nil = proceed with sending data // - error = cancel this request // - err == ErrPause - pause this request - OnDataQueued(chid ChannelID, link ipld.Link, size uint64) (Message, error) + OnDataQueued(chid ChannelID, link ipld.Link, size uint64, index int64, unique bool) (Message, error) // OnDataSent is called when we send data for the given channel ID - OnDataSent(chid ChannelID, link ipld.Link, size uint64) error + OnDataSent(chid ChannelID, link ipld.Link, size uint64, index int64, unique bool) error // OnTransferQueued is called when a new data transfer request is queued in the transport layer. OnTransferQueued(chid ChannelID) diff --git a/transport/graphsync/graphsync.go b/transport/graphsync/graphsync.go index d052a040..8ffaa03b 100644 --- a/transport/graphsync/graphsync.go +++ b/transport/graphsync/graphsync.go @@ -527,7 +527,7 @@ func (t *Transport) gsIncomingBlockHook(p peer.ID, response graphsync.ResponseDa return } - err := t.events.OnDataReceived(chid, block.Link(), block.BlockSize(), block.Index()) + err := t.events.OnDataReceived(chid, block.Link(), block.BlockSize(), block.Index(), block.BlockSizeOnWire() != 0) if err != nil && err != datatransfer.ErrPause { hookActions.TerminateWithError(err) return @@ -553,7 +553,7 @@ func (t *Transport) gsBlockSentHook(p peer.ID, request graphsync.RequestData, bl return } - if err := t.events.OnDataSent(chid, block.Link(), block.BlockSize()); err != nil { + if err := t.events.OnDataSent(chid, block.Link(), block.BlockSize(), block.Index(), block.BlockSizeOnWire() != 0); err != nil { log.Errorf("failed to process data sent: %+v", err) } } @@ -577,7 +577,7 @@ func (t *Transport) gsOutgoingBlockHook(p peer.ID, request graphsync.RequestData // peer. It can return ErrPause to pause the response (eg if payment is // required) and it can return a message that will be sent with the block // (eg to ask for payment). - msg, err := t.events.OnDataQueued(chid, block.Link(), block.BlockSize()) + msg, err := t.events.OnDataQueued(chid, block.Link(), block.BlockSize(), block.Index(), block.BlockSizeOnWire() != 0) if err != nil && err != datatransfer.ErrPause { hookActions.TerminateWithError(err) return diff --git a/transport/graphsync/graphsync_test.go b/transport/graphsync/graphsync_test.go index 7e6943a5..16dbb2df 100644 --- a/transport/graphsync/graphsync_test.go +++ b/transport/graphsync/graphsync_test.go @@ -1199,7 +1199,7 @@ type fakeEvents struct { ResponseReceivedResponse datatransfer.Response } -func (fe *fakeEvents) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64) (datatransfer.Message, error) { +func (fe *fakeEvents) OnDataQueued(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) (datatransfer.Message, error) { fe.OnDataQueuedCalled = true return fe.OnDataQueuedMessage, fe.OnDataQueuedError @@ -1238,12 +1238,12 @@ func (fe *fakeEvents) OnChannelOpened(chid datatransfer.ChannelID) error { return fe.OnChannelOpenedError } -func (fe *fakeEvents) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64) error { +func (fe *fakeEvents) OnDataReceived(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { fe.OnDataReceivedCalled = true return fe.OnDataReceivedError } -func (fe *fakeEvents) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64) error { +func (fe *fakeEvents) OnDataSent(chid datatransfer.ChannelID, link ipld.Link, size uint64, index int64, unique bool) error { fe.OnDataSentCalled = true return nil } @@ -1458,6 +1458,14 @@ func (m *mockChannelState) ReceivedCidsTotal() int64 { return (int64)(len(m.receivedCids)) } +func (m *mockChannelState) QueuedCidsTotal() int64 { + panic("implement me") +} + +func (m *mockChannelState) SentCidsTotal() int64 { + panic("implement me") +} + func (m *mockChannelState) Queued() uint64 { panic("implement me") } diff --git a/types.go b/types.go index ac35f2da..983d8435 100644 --- a/types.go +++ b/types.go @@ -139,6 +139,14 @@ type ChannelState interface { // on the channel - note that a block can exist in more than one place in the DAG ReceivedCidsTotal() int64 + // QueuedCidsTotal returns the number of (non-unique) cids queued so far + // on the channel - note that a block can exist in more than one place in the DAG + QueuedCidsTotal() int64 + + // SentCidsTotal returns the number of (non-unique) cids sent so far + // on the channel - note that a block can exist in more than one place in the DAG + SentCidsTotal() int64 + // Queued returns the number of bytes read from the node and queued for sending Queued() uint64