diff --git a/channels/channels_test.go b/channels/channels_test.go index 74fc45d7..a93bc08c 100644 --- a/channels/channels_test.go +++ b/channels/channels_test.go @@ -11,27 +11,13 @@ import ( "github.com/stretchr/testify/require" ) -type fakeVoucher struct{} - -func (fv *fakeVoucher) ToBytes() ([]byte, error) { - panic("not implemented") -} - -func (fv *fakeVoucher) FromBytes(_ []byte) error { - panic("not implemented") -} - -func (fv *fakeVoucher) Type() string { - panic("not implemented") -} - func TestChannels(t *testing.T) { channels := channels.New() tid1 := datatransfer.TransferID(0) tid2 := datatransfer.TransferID(1) - fv1 := &fakeVoucher{} - fv2 := &fakeVoucher{} + fv1 := &testutil.FakeDTType{} + fv2 := &testutil.FakeDTType{} cids := testutil.GenerateCids(2) selector := builder.NewSelectorSpecBuilder(basicnode.Style.Any).Matcher().Node() peers := testutil.GeneratePeers(4) diff --git a/encoding/encoding.go b/encoding/encoding.go new file mode 100644 index 00000000..7cca7d73 --- /dev/null +++ b/encoding/encoding.go @@ -0,0 +1,122 @@ +package encoding + +import ( + "bytes" + "reflect" + + cbor "github.com/ipfs/go-ipld-cbor" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/codec/dagcbor" + cborgen "github.com/whyrusleeping/cbor-gen" + "golang.org/x/xerrors" +) + +// Encodable is an object that can be written to CBOR and decoded back +type Encodable interface{} + +// Encode encodes an encodable to CBOR, using the best available path for +// writing to CBOR +func Encode(value Encodable) ([]byte, error) { + if cbgEncodable, ok := value.(cborgen.CBORMarshaler); ok { + buf := new(bytes.Buffer) + err := cbgEncodable.MarshalCBOR(buf) + if err != nil { + return nil, err + } + return buf.Bytes(), nil + } + if ipldEncodable, ok := value.(ipld.Node); ok { + buf := new(bytes.Buffer) + err := dagcbor.Encoder(ipldEncodable, buf) + if err != nil { + return nil, err + } + return buf.Bytes(), nil + } + return cbor.DumpObject(value) +} + +// Decoder is CBOR decoder for a given encodable type +type Decoder interface { + DecodeFromCbor([]byte) (Encodable, error) +} + +// NewDecoder creates a new Decoder that will decode into new instances of the given +// object type. It will use the decoding that is optimal for that type +// It returns error if it's not possible to setup a decoder for this type +func NewDecoder(decodeType Encodable) (Decoder, error) { + // check if type is ipld.Node, if so, just use style + if ipldDecodable, ok := decodeType.(ipld.Node); ok { + return &ipldDecoder{ipldDecodable.Style()}, nil + } + // check if type is a pointer, as we need that to make new copies + // for cborgen types & regular IPLD types + decodeReflectType := reflect.TypeOf(decodeType) + if decodeReflectType.Kind() != reflect.Ptr { + return nil, xerrors.New("type must be a pointer") + } + // check if type is a cbor-gen type + if _, ok := decodeType.(cborgen.CBORUnmarshaler); ok { + return &cbgDecoder{decodeReflectType}, nil + } + // type does is neither ipld-prime nor cbor-gen, so we need to see if it + // can rountrip with oldschool ipld-format + encoded, err := cbor.DumpObject(decodeType) + if err != nil { + return nil, xerrors.New("Object type did not encode") + } + newDecodable := reflect.New(decodeReflectType.Elem()).Interface() + if err := cbor.DecodeInto(encoded, newDecodable); err != nil { + return nil, xerrors.New("Object type did not decode") + } + return &defaultDecoder{decodeReflectType}, nil +} + +type ipldDecoder struct { + style ipld.NodeStyle +} + +func (decoder *ipldDecoder) DecodeFromCbor(encoded []byte) (Encodable, error) { + builder := decoder.style.NewBuilder() + buf := bytes.NewReader(encoded) + err := dagcbor.Decoder(builder, buf) + if err != nil { + return nil, err + } + return builder.Build(), nil +} + +type cbgDecoder struct { + cbgType reflect.Type +} + +func (decoder *cbgDecoder) DecodeFromCbor(encoded []byte) (Encodable, error) { + decodedValue := reflect.New(decoder.cbgType.Elem()) + decoded, ok := decodedValue.Interface().(cborgen.CBORUnmarshaler) + if !ok || reflect.ValueOf(decoded).IsNil() { + return nil, xerrors.New("problem instantiating decoded value") + } + buf := bytes.NewReader(encoded) + err := decoded.UnmarshalCBOR(buf) + if err != nil { + return nil, err + } + return decoded, nil +} + +type defaultDecoder struct { + ptrType reflect.Type +} + +func (decoder *defaultDecoder) DecodeFromCbor(encoded []byte) (Encodable, error) { + decodedValue := reflect.New(decoder.ptrType.Elem()) + decoded, ok := decodedValue.Interface().(Encodable) + if !ok || reflect.ValueOf(decoded).IsNil() { + return nil, xerrors.New("problem instantiating decoded value") + } + err := cbor.DecodeInto(encoded, decoded) + if err != nil { + return nil, err + } + return decoded, nil +} diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go new file mode 100644 index 00000000..43a66f5d --- /dev/null +++ b/encoding/encoding_test.go @@ -0,0 +1,37 @@ +package encoding_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/filecoin-project/go-data-transfer/encoding" + "github.com/filecoin-project/go-data-transfer/encoding/testdata" +) + +func TestRoundTrip(t *testing.T) { + testCases := map[string]struct { + val encoding.Encodable + }{ + "can encode/decode IPLD prime types": { + val: testdata.Prime, + }, + "can encode/decode cbor-gen types": { + val: testdata.Cbg, + }, + "can encode/decode old ipld format types": { + val: testdata.Standard, + }, + } + for testCase, data := range testCases { + t.Run(testCase, func(t *testing.T) { + encoded, err := encoding.Encode(data.val) + require.NoError(t, err) + decoder, err := encoding.NewDecoder(data.val) + require.NoError(t, err) + decoded, err := decoder.DecodeFromCbor(encoded) + require.NoError(t, err) + require.Equal(t, data.val, decoded) + }) + } +} diff --git a/encoding/testdata/testdata.go b/encoding/testdata/testdata.go new file mode 100644 index 00000000..c76abc24 --- /dev/null +++ b/encoding/testdata/testdata.go @@ -0,0 +1,37 @@ +package testdata + +import ( + cbor "github.com/ipfs/go-ipld-cbor" + "github.com/ipld/go-ipld-prime/fluent" + basicnode "github.com/ipld/go-ipld-prime/node/basic" +) + +// Prime = an instance of an ipld prime piece of data +var Prime = fluent.MustBuildMap(basicnode.Style.Map, 2, func(na fluent.MapAssembler) { + nva := na.AssembleEntry("X") + nva.AssignInt(100) + nva = na.AssembleEntry("Y") + nva.AssignString("appleSauce") +}) + +type standardType struct { + X int + Y string +} + +func init() { + cbor.RegisterCborType(standardType{}) +} + +// Standard = an instance that is neither ipld prime nor cbor +var Standard *standardType = &standardType{X: 100, Y: "appleSauce"} + +//go:generate cbor-gen-for cbgType + +type cbgType struct { + X uint64 + Y string +} + +// Cbg = an instance of a cbor-gen type +var Cbg *cbgType = &cbgType{X: 100, Y: "appleSauce"} diff --git a/encoding/testdata/testdata_cbor_gen.go b/encoding/testdata/testdata_cbor_gen.go new file mode 100644 index 00000000..67c6c688 --- /dev/null +++ b/encoding/testdata/testdata_cbor_gen.go @@ -0,0 +1,84 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package testdata + +import ( + "fmt" + "io" + + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf + +func (t *cbgType) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{130}); err != nil { + return err + } + + // t.X (uint64) (uint64) + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.X))); err != nil { + return err + } + + // t.Y (string) (string) + if len(t.Y) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Y was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.Y)))); err != nil { + return err + } + if _, err := w.Write([]byte(t.Y)); err != nil { + return err + } + return nil +} + +func (t *cbgType) UnmarshalCBOR(r io.Reader) error { + br := cbg.GetPeeker(r) + + maj, extra, err := cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 2 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.X (uint64) (uint64) + + { + + maj, extra, err = cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.X = uint64(extra) + + } + // t.Y (string) (string) + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + t.Y = string(sval) + } + return nil +} diff --git a/go.mod b/go.mod index dc5bebc5..00cf65c6 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/ipfs/go-ipfs-chunker v0.0.5 github.com/ipfs/go-ipfs-exchange-offline v0.0.1 github.com/ipfs/go-ipfs-files v0.0.8 + github.com/ipfs/go-ipld-cbor v0.0.4 github.com/ipfs/go-ipld-format v0.2.0 github.com/ipfs/go-log v1.0.2 github.com/ipfs/go-merkledag v0.3.1 diff --git a/impl/dagservice/dagservice.go b/impl/dagservice/dagservice.go index 2c9ce797..dcb383b4 100644 --- a/impl/dagservice/dagservice.go +++ b/impl/dagservice/dagservice.go @@ -2,7 +2,6 @@ package datatransfer import ( "context" - "reflect" "time" "github.com/ipfs/go-cid" @@ -12,7 +11,7 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "golang.org/x/xerrors" - "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer" ) // This file implements a VERY simple, incomplete version of the data transfer @@ -38,7 +37,7 @@ func NewDAGServiceDataTransfer(dag ipldformat.DAGService) datatransfer.Manager { // RegisterVoucherType registers a validator for the given voucher type // will error if voucher type does not implement voucher // or if there is a voucher type registered with an identical identifier -func (impl *dagserviceImpl) RegisterVoucherType(voucherType reflect.Type, validator datatransfer.RequestValidator) error { +func (impl *dagserviceImpl) RegisterVoucherType(voucherType datatransfer.Voucher, validator datatransfer.RequestValidator) error { return nil } diff --git a/impl/graphsync/graphsync_impl.go b/impl/graphsync/graphsync_impl.go index cb6e1dce..4582a5d2 100644 --- a/impl/graphsync/graphsync_impl.go +++ b/impl/graphsync/graphsync_impl.go @@ -4,25 +4,28 @@ import ( "bytes" "context" "errors" - "fmt" - "reflect" "time" "github.com/ipfs/go-cid" "github.com/ipfs/go-graphsync" + logging "github.com/ipfs/go-log" "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" + "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer" "github.com/filecoin-project/go-data-transfer/channels" "github.com/filecoin-project/go-data-transfer/message" "github.com/filecoin-project/go-data-transfer/network" + "github.com/filecoin-project/go-data-transfer/registry" "github.com/filecoin-project/go-storedcounter" "github.com/hannahhoward/go-pubsub" ) +var log = logging.Logger("graphsync-impl") + // This file implements a VERY simple, incomplete version of the data transfer // module that allows us to make the necessary insertions of data transfer // functionality into the storage market @@ -30,14 +33,9 @@ import ( // -- support multiple subscribers // -- do any actual network coordination or use Graphsync -type validateType struct { - voucherType reflect.Type // nolint: structcheck - validator datatransfer.RequestValidator // nolint: structcheck -} - type graphsyncImpl struct { dataTransferNetwork network.DataTransferNetwork - validatedTypes map[string]validateType + validatedTypes *registry.Registry pubSub *pubsub.PubSub channels *channels.Channels gs graphsync.GraphExchange @@ -68,7 +66,7 @@ func NewGraphSyncDataTransfer(host host.Host, gs graphsync.GraphExchange, stored dataTransferNetwork := network.NewFromLibp2pHost(host) impl := &graphsyncImpl{ dataTransferNetwork, - make(map[string]validateType), + registry.NewRegistry(), pubsub.New(dispatcher), channels.New(), gs, @@ -147,24 +145,10 @@ func (impl *graphsyncImpl) gsCompletedResponseListener(p peer.ID, request graphs // * voucher type does not implement voucher // * there is a voucher type registered with an identical identifier // * voucherType's Kind is not reflect.Ptr -func (impl *graphsyncImpl) RegisterVoucherType(voucherType reflect.Type, validator datatransfer.RequestValidator) error { - if voucherType.Kind() != reflect.Ptr { - return fmt.Errorf("voucherType must be a reflect.Ptr Kind") - } - v := reflect.New(voucherType.Elem()) - voucher, ok := v.Interface().(datatransfer.Voucher) - if !ok { - return fmt.Errorf("voucher does not implement Voucher interface") - } - - _, isReg := impl.validatedTypes[voucher.Type()] - if isReg { - return fmt.Errorf("voucher type already registered: %s", voucherType.String()) - } - - impl.validatedTypes[voucher.Type()] = validateType{ - voucherType: voucherType, - validator: validator, +func (impl *graphsyncImpl) RegisterVoucherType(voucherType datatransfer.Voucher, validator datatransfer.RequestValidator) error { + err := impl.validatedTypes.Register(voucherType, validator) + if err != nil { + return xerrors.Errorf("error registering voucher type: %w", err) } return nil } @@ -204,21 +188,15 @@ func (impl *graphsyncImpl) OpenPullDataChannel(ctx context.Context, requestTo pe // sendDtRequest encapsulates message creation and posting to the data transfer network with the provided parameters func (impl *graphsyncImpl) sendDtRequest(ctx context.Context, selector ipld.Node, isPull bool, voucher datatransfer.Voucher, baseCid cid.Cid, to peer.ID) (datatransfer.TransferID, error) { - sbytes, err := nodeAsBytes(selector) - if err != nil { - return 0, err - } - vbytes, err := voucher.ToBytes() + next, err := impl.storedCounter.Next() if err != nil { return 0, err } - next, err := impl.storedCounter.Next() + tid := datatransfer.TransferID(next) + req, err := message.NewRequest(tid, isPull, voucher.Type(), voucher, baseCid, selector) if err != nil { return 0, err } - tid := datatransfer.TransferID(next) - req := message.NewRequest(tid, isPull, voucher.Type(), vbytes, baseCid, sbytes) - if err := impl.dataTransferNetwork.SendMessage(ctx, to, req); err != nil { return 0, err } diff --git a/impl/graphsync/graphsync_impl_test.go b/impl/graphsync/graphsync_impl_test.go index 740adef3..77c1593e 100644 --- a/impl/graphsync/graphsync_impl_test.go +++ b/impl/graphsync/graphsync_impl_test.go @@ -5,7 +5,6 @@ import ( "context" "errors" "math/rand" - "reflect" "testing" "time" @@ -13,7 +12,6 @@ import ( "github.com/ipfs/go-graphsync" gsmsg "github.com/ipfs/go-graphsync/message" "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/codec/dagcbor" cidlink "github.com/ipld/go-ipld-prime/linking/cid" basicnode "github.com/ipld/go-ipld-prime/node/basic" "github.com/ipld/go-ipld-prime/traversal/selector" @@ -64,23 +62,6 @@ func (r *receiver) ReceiveResponse( func (r *receiver) ReceiveError(err error) { } -type fakeDTType struct { - data string -} - -func (ft *fakeDTType) ToBytes() ([]byte, error) { - return []byte(ft.data), nil -} - -func (ft *fakeDTType) FromBytes(data []byte) error { - ft.data = string(data) - return nil -} - -func (ft *fakeDTType) Type() string { - return "FakeDTType" -} - func TestDataTransferOneWay(t *testing.T) { // create network ctx := context.Background() @@ -107,9 +88,9 @@ func TestDataTransferOneWay(t *testing.T) { stor := ssb.ExploreRecursive(selector.RecursionLimitNone(), ssb.ExploreAll(ssb.ExploreRecursiveEdge())).Node() - voucher := fakeDTType{"applesauce"} + voucher := testutil.NewFakeDTType() baseCid := testutil.GenerateCids(1)[0] - channelID, err := dt.OpenPushDataChannel(ctx, host2.ID(), &voucher, baseCid, stor) + channelID, err := dt.OpenPushDataChannel(ctx, host2.ID(), voucher, baseCid, stor) require.NoError(t, err) require.NotNil(t, channelID) require.Equal(t, channelID.Initiator, host1.ID()) @@ -134,17 +115,10 @@ func TestDataTransferOneWay(t *testing.T) { require.Equal(t, receivedRequest.BaseCid(), baseCid) require.False(t, receivedRequest.IsCancel()) require.False(t, receivedRequest.IsPull()) - reader := bytes.NewReader(receivedRequest.Selector()) - nb := basicnode.Style.Any.NewBuilder() - err = dagcbor.Decoder(nb, reader) + receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) - receivedSelector := nb.Build() require.Equal(t, receivedSelector, stor) - receivedVoucher := new(fakeDTType) - err = receivedVoucher.FromBytes(receivedRequest.Voucher()) - require.NoError(t, err) - require.Equal(t, *receivedVoucher, voucher) - require.Equal(t, receivedRequest.VoucherType(), voucher.Type()) + testutil.AssertFakeDTVoucher(t, receivedRequest, voucher) }) t.Run("OpenPullDataTransfer", func(t *testing.T) { @@ -153,9 +127,9 @@ func TestDataTransferOneWay(t *testing.T) { stor := ssb.ExploreRecursive(selector.RecursionLimitNone(), ssb.ExploreAll(ssb.ExploreRecursiveEdge())).Node() - voucher := fakeDTType{"applesauce"} + voucher := testutil.NewFakeDTType() baseCid := testutil.GenerateCids(1)[0] - channelID, err := dt.OpenPullDataChannel(ctx, host2.ID(), &voucher, baseCid, stor) + channelID, err := dt.OpenPullDataChannel(ctx, host2.ID(), voucher, baseCid, stor) require.NoError(t, err) require.NotNil(t, channelID) require.Equal(t, channelID.Initiator, host1.ID()) @@ -180,17 +154,10 @@ func TestDataTransferOneWay(t *testing.T) { require.Equal(t, receivedRequest.BaseCid(), baseCid) require.False(t, receivedRequest.IsCancel()) require.True(t, receivedRequest.IsPull()) - reader := bytes.NewReader(receivedRequest.Selector()) - nb := basicnode.Style.Any.NewBuilder() - err = dagcbor.Decoder(nb, reader) + receivedSelector, err := receivedRequest.Selector() require.NoError(t, err) - receivedSelector := nb.Build() require.Equal(t, receivedSelector, stor) - receivedVoucher := new(fakeDTType) - err = receivedVoucher.FromBytes(receivedRequest.Voucher()) - require.NoError(t, err) - require.Equal(t, *receivedVoucher, voucher) - require.Equal(t, receivedRequest.VoucherType(), voucher.Type()) + testutil.AssertFakeDTVoucher(t, receivedRequest, voucher) }) } @@ -252,15 +219,13 @@ func TestDataTransferValidation(t *testing.T) { fv := &fakeValidator{ctx, make(chan receivedValidation)} id := datatransfer.TransferID(rand.Int31()) - var buffer bytes.Buffer - require.NoError(t, dagcbor.Encoder(gsData.AllSelector, &buffer)) t.Run("ValidatePush", func(t *testing.T) { dt2 := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - err := dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), fv) + err := dt2.RegisterVoucherType(&testutil.FakeDTType{}, fv) require.NoError(t, err) // create push request - voucher, baseCid, request := createDTRequest(t, false, id, buffer.Bytes()) + voucher, baseCid, request := createDTRequest(t, false, id, gsData.AllSelector) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) @@ -287,7 +252,7 @@ func TestDataTransferValidation(t *testing.T) { t.Run("ValidatePull", func(t *testing.T) { // create pull request - voucher, baseCid, request := createDTRequest(t, true, id, buffer.Bytes()) + voucher, baseCid, request := createDTRequest(t, true, id, gsData.AllSelector) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) var validation receivedValidation @@ -310,13 +275,12 @@ func TestDataTransferValidation(t *testing.T) { }) } -func createDTRequest(t *testing.T, isPull bool, id datatransfer.TransferID, selectorBytes []byte) (fakeDTType, cid.Cid, message.DataTransferRequest) { - voucher := fakeDTType{"applesauce"} +func createDTRequest(t *testing.T, isPull bool, id datatransfer.TransferID, selector ipld.Node) (testutil.FakeDTType, cid.Cid, message.DataTransferRequest) { + voucher := &testutil.FakeDTType{Data: "applesauce"} baseCid := testutil.GenerateCids(1)[0] - voucherBytes, err := voucher.ToBytes() + request, err := message.NewRequest(id, isPull, voucher.Type(), voucher, baseCid, selector) require.NoError(t, err) - request := message.NewRequest(id, isPull, voucher.Type(), voucherBytes, baseCid, selectorBytes) - return voucher, baseCid, request + return *voucher, baseCid, request } type stubbedValidator struct { @@ -407,14 +371,14 @@ func TestGraphsyncImpl_RegisterVoucherType(t *testing.T) { fv := &fakeValidator{ctx, make(chan receivedValidation)} // a voucher type can be registered - assert.NoError(t, dt.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), fv)) + assert.NoError(t, dt.RegisterVoucherType(&testutil.FakeDTType{}, fv)) // it cannot be re-registered - assert.EqualError(t, dt.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), fv), "voucher type already registered: *graphsyncimpl_test.fakeDTType") + assert.EqualError(t, dt.RegisterVoucherType(&testutil.FakeDTType{}, fv), "error registering voucher type: identifier already registered: FakeDTType") // it must be registered as a pointer - assert.EqualError(t, dt.RegisterVoucherType(reflect.TypeOf(fakeDTType{}), fv), - "voucherType must be a reflect.Ptr Kind") + assert.EqualError(t, dt.RegisterVoucherType(testutil.FakeDTType{}, fv), + "error registering voucher type: registering entry type FakeDTType: type must be a pointer") } func TestDataTransferSubscribing(t *testing.T) { @@ -432,8 +396,8 @@ func TestDataTransferSubscribing(t *testing.T) { sv.stubErrorPull() sv.stubErrorPush() dt2 := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - require.NoError(t, dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv)) - voucher := fakeDTType{"applesauce"} + require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) + voucher := testutil.FakeDTType{Data: "applesauce"} baseCid := testutil.GenerateCids(1)[0] dt1 := NewGraphSyncDataTransfer(host1, gs1, gsData.StoredCounter1) @@ -524,19 +488,15 @@ func TestDataTransferInitiatingPushGraphsyncRequests(t *testing.T) { dtnet1.SetDelegate(r) id := datatransfer.TransferID(rand.Int31()) - var buffer bytes.Buffer - err := dagcbor.Encoder(gsData.AllSelector, &buffer) - require.NoError(t, err) - - _, baseCid, request := createDTRequest(t, false, id, buffer.Bytes()) + _, baseCid, request := createDTRequest(t, false, id, gsData.AllSelector) t.Run("with successful validation", func(t *testing.T) { sv := newSV() sv.expectSuccessPush() dt2 := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - require.NoError(t, dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv)) + require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) select { @@ -566,7 +526,7 @@ func TestDataTransferInitiatingPushGraphsyncRequests(t *testing.T) { sv.expectErrorPush() dt2 := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - require.NoError(t, dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv)) + require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) select { @@ -589,7 +549,7 @@ func TestDataTransferInitiatingPullGraphsyncRequests(t *testing.T) { host1 := gsData.Host1 // initiates the pull request host2 := gsData.Host2 // sends the data - voucher := fakeDTType{"applesauce"} + voucher := testutil.FakeDTType{Data: "applesauce"} baseCid := testutil.GenerateCids(1)[0] t.Run("with successful validation", func(t *testing.T) { @@ -604,7 +564,7 @@ func TestDataTransferInitiatingPullGraphsyncRequests(t *testing.T) { dtInit := NewGraphSyncDataTransfer(host1, gs1Init, gsData.StoredCounter1) dtSender := NewGraphSyncDataTransfer(host2, gs2Sender, gsData.StoredCounter2) - err := dtSender.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) + err := dtSender.RegisterVoucherType(&testutil.FakeDTType{}, sv) require.NoError(t, err) _, err = dtInit.OpenPullDataChannel(ctx, host2.ID(), &voucher, baseCid, gsData.AllSelector) @@ -632,7 +592,7 @@ func TestDataTransferInitiatingPullGraphsyncRequests(t *testing.T) { sv.expectErrorPull() dt2 := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - err := dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) + err := dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv) require.NoError(t, err) subscribeCalls := make(chan struct{}, 1) @@ -670,7 +630,7 @@ func TestDataTransferInitiatingPullGraphsyncRequests(t *testing.T) { dt1 := NewGraphSyncDataTransfer(host1, gs1, gsData.StoredCounter1) dt2 := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - err := dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) + err := dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv) require.NoError(t, err) subscribeCalls := make(chan struct{}, 1) @@ -742,7 +702,7 @@ func TestRespondingToPushGraphsyncRequests(t *testing.T) { gsData := testutil.NewGraphsyncTestingData(ctx, t) host1 := gsData.Host1 // initiator and data sender host2 := gsData.Host2 // data recipient, makes graphsync request for data - voucher := fakeDTType{"applesauce"} + voucher := testutil.FakeDTType{Data: "applesauce"} link := gsData.LoadUnixFSFile(t, false) // setup receiving peer to just record message coming in @@ -818,7 +778,7 @@ func TestResponseHookWhenExtensionNotFound(t *testing.T) { gsData := testutil.NewGraphsyncTestingData(ctx, t) host1 := gsData.Host1 // initiator and data sender host2 := gsData.Host2 // data recipient, makes graphsync request for data - voucher := fakeDTType{"applesauce"} + voucher := testutil.FakeDTType{Data: "applesauce"} link := gsData.LoadUnixFSFile(t, false) // setup receiving peer to just record message coming in @@ -890,19 +850,15 @@ func TestRespondingToPullGraphsyncRequests(t *testing.T) { link := gsData.LoadUnixFSFile(t, true) id := datatransfer.TransferID(rand.Int31()) - var buf bytes.Buffer - err := dagcbor.Encoder(gsData.AllSelector, &buf) - require.NoError(t, err) - selectorBytes := buf.Bytes() t.Run("When a pull request is initiated and validated", func(t *testing.T) { sv := newSV() sv.expectSuccessPull() dt1 := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - require.NoError(t, dt1.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv)) + require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) - _, _, request := createDTRequest(t, true, id, selectorBytes) + _, _, request := createDTRequest(t, true, id, gsData.AllSelector) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) var messageReceived receivedMessage select { @@ -921,7 +877,7 @@ func TestRespondingToPullGraphsyncRequests(t *testing.T) { } var buf2 = bytes.Buffer{} - err = extStruct.MarshalCBOR(&buf2) + err := extStruct.MarshalCBOR(&buf2) require.NoError(t, err) extData := buf2.Bytes() @@ -943,7 +899,7 @@ func TestRespondingToPullGraphsyncRequests(t *testing.T) { extStruct := &ExtensionDataTransferData{TransferID: rand.Uint64(), Initiator: host1.ID()} var buf2 bytes.Buffer - err = extStruct.MarshalCBOR(&buf2) + err := extStruct.MarshalCBOR(&buf2) require.NoError(t, err) extData := buf2.Bytes() request := gsmsg.NewRequest(graphsync.RequestID(rand.Int31()), link.(cidlink.Link).Cid, gsData.AllSelector, graphsync.Priority(rand.Int31()), graphsync.ExtensionData{ @@ -986,10 +942,10 @@ func TestDataTransferPushRoundTrip(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := fakeDTType{"applesauce"} + voucher := testutil.FakeDTType{Data: "applesauce"} sv := newSV() sv.expectSuccessPull() - require.NoError(t, dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv)) + require.NoError(t, dt2.RegisterVoucherType(&testutil.FakeDTType{}, sv)) chid, err := dt1.OpenPushDataChannel(ctx, host2.ID(), &voucher, rootCid, gsData.AllSelector) require.NoError(t, err) @@ -1029,10 +985,10 @@ func TestDataTransferPullRoundTrip(t *testing.T) { } dt1.SubscribeToEvents(subscriber) dt2.SubscribeToEvents(subscriber) - voucher := fakeDTType{"applesauce"} + voucher := testutil.FakeDTType{Data: "applesauce"} sv := newSV() sv.expectSuccessPull() - require.NoError(t, dt1.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv)) + require.NoError(t, dt1.RegisterVoucherType(&testutil.FakeDTType{}, sv)) _, err := dt2.OpenPullDataChannel(ctx, host1.ID(), &voucher, rootCid, gsData.AllSelector) require.NoError(t, err) diff --git a/impl/graphsync/graphsync_receiver.go b/impl/graphsync/graphsync_receiver.go index d400c3ec..449f841c 100644 --- a/impl/graphsync/graphsync_receiver.go +++ b/impl/graphsync/graphsync_receiver.go @@ -2,14 +2,13 @@ package graphsyncimpl import ( "context" - "fmt" - "reflect" "time" "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/peer" + xerrors "golang.org/x/xerrors" datatransfer "github.com/filecoin-project/go-data-transfer" "github.com/filecoin-project/go-data-transfer/message" @@ -31,7 +30,7 @@ func (receiver *graphsyncReceiver) ReceiveRequest( receiver.impl.sendResponse(ctx, false, initiator, incoming.TransferID()) return } - stor, _ := nodeFromBytes(incoming.Selector()) + stor, _ := incoming.Selector() root := cidlink.Link{Cid: incoming.BaseCid()} var dataSender, dataReceiver peer.ID @@ -56,25 +55,32 @@ func (receiver *graphsyncReceiver) ReceiveRequest( // validateVoucher converts a voucher in an incoming message to its appropriate // voucher struct, then runs the validator and returns the results. // returns error if: -// * voucherFromRequest fails +// * reading voucher fails // * deserialization of selector fails // * validation fails func (receiver *graphsyncReceiver) validateVoucher(sender peer.ID, incoming message.DataTransferRequest) (datatransfer.Voucher, error) { - vtypStr := incoming.VoucherType() - vouch, err := receiver.voucherFromRequest(incoming) + vtypStr := datatransfer.TypeIdentifier(incoming.VoucherType()) + decoder, has := receiver.impl.validatedTypes.Decoder(vtypStr) + if !has { + return nil, xerrors.Errorf("unknown voucher type: %s", vtypStr) + } + encodable, err := incoming.Voucher(decoder) if err != nil { - return vouch, err + return nil, err } + vouch := encodable.(datatransfer.Registerable) var validatorFunc func(peer.ID, datatransfer.Voucher, cid.Cid, ipld.Node) error + processor, _ := receiver.impl.validatedTypes.Processor(vtypStr) + validator := processor.(datatransfer.RequestValidator) if incoming.IsPull() { - validatorFunc = receiver.impl.validatedTypes[vtypStr].validator.ValidatePull + validatorFunc = validator.ValidatePull } else { - validatorFunc = receiver.impl.validatedTypes[vtypStr].validator.ValidatePush + validatorFunc = validator.ValidatePush } - stor, err := nodeFromBytes(incoming.Selector()) + stor, err := incoming.Selector() if err != nil { return vouch, err } @@ -86,30 +92,6 @@ func (receiver *graphsyncReceiver) validateVoucher(sender peer.ID, incoming mess return vouch, nil } -// voucherFromRequest takes an incoming request and attempts to create a -// voucher struct from it using the registered validated types. It returns -// a deserialized voucher and any error. It returns error if: -// * the voucher type has no validator registered -// * the voucher cannot be instantiated via reflection -// * request voucher bytes cannot be deserialized via .FromBytes() -func (receiver *graphsyncReceiver) voucherFromRequest(incoming message.DataTransferRequest) (datatransfer.Voucher, error) { - vtypStr := incoming.VoucherType() - - validatedType, ok := receiver.impl.validatedTypes[vtypStr] - if !ok { - return nil, fmt.Errorf("unregistered voucher type %s", vtypStr) - } - vStructVal := reflect.New(validatedType.voucherType.Elem()) - voucher, ok := vStructVal.Interface().(datatransfer.Voucher) - if !ok || reflect.ValueOf(voucher).IsNil() { - return nil, fmt.Errorf("problem instantiating type %s, voucher: %v", vtypStr, voucher) - } - if err := voucher.FromBytes(incoming.Voucher()); err != nil { - return voucher, err - } - return voucher, nil -} - // ReceiveResponse handles responses to our Push or Pull data transfer request. // It schedules a graphsync transfer only if our Pull Request is accepted. func (receiver *graphsyncReceiver) ReceiveResponse( diff --git a/impl/graphsync/graphsync_receiver_test.go b/impl/graphsync/graphsync_receiver_test.go index c0007b1f..58d72635 100644 --- a/impl/graphsync/graphsync_receiver_test.go +++ b/impl/graphsync/graphsync_receiver_test.go @@ -1,14 +1,11 @@ package graphsyncimpl_test import ( - "bytes" "context" "math/rand" - "reflect" "testing" "time" - "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -38,11 +35,8 @@ func TestSendResponseToIncomingRequest(t *testing.T) { gs2 := testutil.NewFakeGraphSync() - voucher := fakeDTType{"applesauce"} + voucher := testutil.NewFakeDTType() baseCid := testutil.GenerateCids(1)[0] - var buffer bytes.Buffer - err := dagcbor.Encoder(gsData.AllSelector, &buffer) - require.NoError(t, err) t.Run("Response to push with successful validation", func(t *testing.T) { id := datatransfer.TransferID(rand.Int31()) @@ -50,13 +44,13 @@ func TestSendResponseToIncomingRequest(t *testing.T) { sv.expectSuccessPush() dt := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - require.NoError(t, dt.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv)) + require.NoError(t, dt.RegisterVoucherType(&testutil.FakeDTType{}, sv)) isPull := false - voucherBytes, err := voucher.ToBytes() + _, err := message.NewRequest(id, isPull, voucher.Type(), voucher, baseCid, gsData.AllSelector) + require.NoError(t, err) + request, err := message.NewRequest(id, isPull, voucher.Type(), voucher, baseCid, gsData.AllSelector) require.NoError(t, err) - _ = message.NewRequest(id, isPull, voucher.Type(), voucherBytes, baseCid, buffer.Bytes()) - request := message.NewRequest(id, isPull, voucher.Type(), voucherBytes, baseCid, buffer.Bytes()) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) var messageReceived receivedMessage select { @@ -85,14 +79,13 @@ func TestSendResponseToIncomingRequest(t *testing.T) { sv := newSV() sv.expectErrorPush() dt := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - err = dt.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) + err := dt.RegisterVoucherType(&testutil.FakeDTType{}, sv) require.NoError(t, err) isPull := false - voucherBytes, err := voucher.ToBytes() + request, err := message.NewRequest(id, isPull, voucher.Type(), voucher, baseCid, gsData.AllSelector) require.NoError(t, err) - request := message.NewRequest(id, isPull, voucher.Type(), voucherBytes, baseCid, buffer.Bytes()) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) var messageReceived receivedMessage @@ -122,15 +115,13 @@ func TestSendResponseToIncomingRequest(t *testing.T) { sv.expectSuccessPull() dt := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - err = dt.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) + err := dt.RegisterVoucherType(&testutil.FakeDTType{}, sv) require.NoError(t, err) isPull := true - voucherBytes, err := voucher.ToBytes() + request, err := message.NewRequest(id, isPull, voucher.Type(), voucher, baseCid, gsData.AllSelector) require.NoError(t, err) - request := message.NewRequest(id, isPull, voucher.Type(), voucherBytes, baseCid, buffer.Bytes()) - require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) var messageReceived receivedMessage select { @@ -159,13 +150,13 @@ func TestSendResponseToIncomingRequest(t *testing.T) { sv.expectErrorPull() dt := NewGraphSyncDataTransfer(host2, gs2, gsData.StoredCounter2) - err = dt.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) + err := dt.RegisterVoucherType(&testutil.FakeDTType{}, sv) require.NoError(t, err) isPull := true - voucherBytes, err := voucher.ToBytes() + + request, err := message.NewRequest(id, isPull, voucher.Type(), voucher, baseCid, gsData.AllSelector) require.NoError(t, err) - request := message.NewRequest(id, isPull, voucher.Type(), voucherBytes, baseCid, buffer.Bytes()) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) var messageReceived receivedMessage diff --git a/impl/graphsync/graphsync_impl_cbor_gen.go b/impl/graphsync/gsextension_cbor_gen.go similarity index 85% rename from impl/graphsync/graphsync_impl_cbor_gen.go rename to impl/graphsync/gsextension_cbor_gen.go index f61acf15..e78a773c 100644 --- a/impl/graphsync/graphsync_impl_cbor_gen.go +++ b/impl/graphsync/gsextension_cbor_gen.go @@ -1,3 +1,5 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + package graphsyncimpl import ( @@ -9,8 +11,6 @@ import ( xerrors "golang.org/x/xerrors" ) -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. - var _ = xerrors.Errorf func (t *ExtensionDataTransferData) MarshalCBOR(w io.Writer) error { @@ -23,11 +23,16 @@ func (t *ExtensionDataTransferData) MarshalCBOR(w io.Writer) error { } // t.TransferID (uint64) (uint64) + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.TransferID))); err != nil { return err } // t.Initiator (peer.ID) (string) + if len(t.Initiator) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Initiator was too long") + } + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.Initiator)))); err != nil { return err } @@ -59,14 +64,18 @@ func (t *ExtensionDataTransferData) UnmarshalCBOR(r io.Reader) error { // t.TransferID (uint64) (uint64) - maj, extra, err = cbg.CborReadHeader(br) - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") + { + + maj, extra, err = cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.TransferID = uint64(extra) + } - t.TransferID = uint64(extra) // t.Initiator (peer.ID) (string) { diff --git a/impl/graphsync/utils.go b/impl/graphsync/utils.go deleted file mode 100644 index babebbf1..00000000 --- a/impl/graphsync/utils.go +++ /dev/null @@ -1,33 +0,0 @@ -package graphsyncimpl - -import ( - "bytes" - - logging "github.com/ipfs/go-log" - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/codec/dagcbor" - basicnode "github.com/ipld/go-ipld-prime/node/basic" -) - -var log = logging.Logger("graphsync-impl") - -// nodeAsBytes serializes an ipld.Node -func nodeAsBytes(node ipld.Node) ([]byte, error) { - var buffer bytes.Buffer - err := dagcbor.Encoder(node, &buffer) - if err != nil { - return nil, err - } - return buffer.Bytes(), nil -} - -// nodeFromBytes deserializes an ipld.Node -func nodeFromBytes(from []byte) (ipld.Node, error) { - reader := bytes.NewReader(from) - nb := basicnode.Style.Any.NewBuilder() - err := dagcbor.Decoder(nb, reader) - if err != nil { - return nil, err - } - return nb.Build(), err -} diff --git a/message/message.go b/message/message.go index 3475d814..d9b07dde 100644 --- a/message/message.go +++ b/message/message.go @@ -4,9 +4,12 @@ import ( "io" "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime" cborgen "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" - "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer" + "github.com/filecoin-project/go-data-transfer/encoding" ) // Reference file: https://github.com/ipfs/go-graphsync/blob/master/message/message.go @@ -27,10 +30,10 @@ type DataTransferMessage interface { type DataTransferRequest interface { DataTransferMessage IsPull() bool - VoucherType() string - Voucher() []byte + VoucherType() datatransfer.TypeIdentifier + Voucher(decoder encoding.Decoder) (encoding.Encodable, error) BaseCid() cid.Cid - Selector() []byte + Selector() (ipld.Node, error) IsCancel() bool } @@ -41,15 +44,26 @@ type DataTransferResponse interface { } // NewRequest generates a new request for the data transfer protocol -func NewRequest(id datatransfer.TransferID, isPull bool, voucherIdentifier string, voucher []byte, baseCid cid.Cid, selector []byte) DataTransferRequest { +func NewRequest(id datatransfer.TransferID, isPull bool, vtype datatransfer.TypeIdentifier, voucher encoding.Encodable, baseCid cid.Cid, selector ipld.Node) (DataTransferRequest, error) { + vbytes, err := encoding.Encode(voucher) + if err != nil { + return nil, xerrors.Errorf("Creating request: %w", err) + } + if baseCid == cid.Undef { + return nil, xerrors.Errorf("base CID must be defined") + } + selBytes, err := encoding.Encode(selector) + if err != nil { + return nil, xerrors.Errorf("Error encoding selector") + } return &transferRequest{ Pull: isPull, - Vouch: voucher, - Stor: selector, - BCid: baseCid.String(), - VTyp: voucherIdentifier, + Vouch: &cborgen.Deferred{Raw: vbytes}, + Stor: &cborgen.Deferred{Raw: selBytes}, + BCid: &baseCid, + VTyp: vtype, XferID: uint64(id), - } + }, nil } // CancelRequest request generates a request to cancel an in progress request diff --git a/message/message_test.go b/message/message_test.go index 89a10159..2f023f3a 100644 --- a/message/message_test.go +++ b/message/message_test.go @@ -5,32 +5,33 @@ import ( "math/rand" "testing" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/ipld/go-ipld-prime/traversal/selector/builder" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer" . "github.com/filecoin-project/go-data-transfer/message" "github.com/filecoin-project/go-data-transfer/testutil" ) func TestNewRequest(t *testing.T) { baseCid := testutil.GenerateCids(1)[0] - selector := testutil.RandomBytes(100) + selector := builder.NewSelectorSpecBuilder(basicnode.Style.Any).Matcher().Node() isPull := true id := datatransfer.TransferID(rand.Int31()) - vtype := "FakeVoucherType" - voucher := testutil.RandomBytes(100) - - request := NewRequest(id, isPull, vtype, voucher, baseCid, selector) + voucher := testutil.NewFakeDTType() + request, err := NewRequest(id, isPull, voucher.Type(), voucher, baseCid, selector) + require.NoError(t, err) assert.Equal(t, id, request.TransferID()) assert.False(t, request.IsCancel()) assert.True(t, request.IsPull()) assert.True(t, request.IsRequest()) assert.Equal(t, baseCid.String(), request.BaseCid().String()) - assert.Equal(t, vtype, request.VoucherType()) - assert.Equal(t, voucher, request.Voucher()) - assert.Equal(t, selector, request.Selector()) - + testutil.AssertFakeDTVoucher(t, request, voucher) + receivedSelector, err := request.Selector() + require.NoError(t, err) + require.Equal(t, selector, receivedSelector) // Sanity check to make sure we can cast to DataTransferMessage msg, ok := request.(DataTransferMessage) require.True(t, ok) @@ -40,13 +41,15 @@ func TestNewRequest(t *testing.T) { } func TestTransferRequest_MarshalCBOR(t *testing.T) { // sanity check MarshalCBOR does its thing w/o error - req := NewTestTransferRequest() + req, err := NewTestTransferRequest() + require.NoError(t, err) wbuf := new(bytes.Buffer) require.NoError(t, req.MarshalCBOR(wbuf)) assert.Greater(t, wbuf.Len(), 0) } func TestTransferRequest_UnmarshalCBOR(t *testing.T) { - req := NewTestTransferRequest() + req, err := NewTestTransferRequest() + require.NoError(t, err) wbuf := new(bytes.Buffer) // use ToNet / FromNet require.NoError(t, req.ToNet(wbuf)) @@ -62,9 +65,8 @@ func TestTransferRequest_UnmarshalCBOR(t *testing.T) { assert.Equal(t, req.IsPull(), desReq.IsPull()) assert.Equal(t, req.IsCancel(), desReq.IsCancel()) assert.Equal(t, req.BaseCid(), desReq.BaseCid()) - assert.Equal(t, req.VoucherType(), desReq.VoucherType()) - assert.Equal(t, req.Voucher(), desReq.Voucher()) - assert.Equal(t, req.Selector(), desReq.Selector()) + testutil.AssertEqualFakeDTVoucher(t, req, desReq) + testutil.AssertEqualSelector(t, req, desReq) } func TestResponses(t *testing.T) { @@ -132,15 +134,15 @@ func TestRequestCancel(t *testing.T) { func TestToNetFromNetEquivalency(t *testing.T) { baseCid := testutil.GenerateCids(1)[0] - selector := testutil.RandomBytes(100) + selector := builder.NewSelectorSpecBuilder(basicnode.Style.Any).Matcher().Node() isPull := false id := datatransfer.TransferID(rand.Int31()) accepted := false - voucherType := "FakeVoucherType" - voucher := testutil.RandomBytes(100) - request := NewRequest(id, isPull, voucherType, voucher, baseCid, selector) + voucher := testutil.NewFakeDTType() + request, err := NewRequest(id, isPull, voucher.Type(), voucher, baseCid, selector) + require.NoError(t, err) buf := new(bytes.Buffer) - err := request.ToNet(buf) + err = request.ToNet(buf) require.NoError(t, err) require.Greater(t, buf.Len(), 0) deserialized, err := FromNet(buf) @@ -154,9 +156,8 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedRequest.IsPull(), request.IsPull()) require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) require.Equal(t, deserializedRequest.BaseCid(), request.BaseCid()) - require.Equal(t, deserializedRequest.VoucherType(), request.VoucherType()) - require.Equal(t, deserializedRequest.Voucher(), request.Voucher()) - require.Equal(t, deserializedRequest.Selector(), request.Selector()) + testutil.AssertEqualFakeDTVoucher(t, request, deserializedRequest) + testutil.AssertEqualSelector(t, request, deserializedRequest) response := NewResponse(id, accepted) err = response.ToNet(buf) @@ -185,12 +186,11 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.Equal(t, deserializedRequest.IsRequest(), request.IsRequest()) } -func NewTestTransferRequest() DataTransferRequest { +func NewTestTransferRequest() (DataTransferRequest, error) { bcid := testutil.GenerateCids(1)[0] - selector := testutil.RandomBytes(100) + selector := builder.NewSelectorSpecBuilder(basicnode.Style.Any).Matcher().Node() isPull := false id := datatransfer.TransferID(rand.Int31()) - vtype := "FakeVoucherType" - v := testutil.RandomBytes(100) - return NewRequest(id, isPull, vtype, v, bcid, selector) + voucher := testutil.NewFakeDTType() + return NewRequest(id, isPull, voucher.Type(), voucher, bcid, selector) } diff --git a/message/transfer_message_cbor_gen.go b/message/transfer_message_cbor_gen.go index f53464e8..c49b260d 100644 --- a/message/transfer_message_cbor_gen.go +++ b/message/transfer_message_cbor_gen.go @@ -1,3 +1,5 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + package message import ( @@ -8,8 +10,6 @@ import ( xerrors "golang.org/x/xerrors" ) -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. - var _ = xerrors.Errorf func (t *transferMessage) MarshalCBOR(w io.Writer) error { @@ -86,7 +86,7 @@ func (t *transferMessage) UnmarshalCBOR(r io.Reader) error { } else { t.Request = new(transferRequest) if err := t.Request.UnmarshalCBOR(br); err != nil { - return err + return xerrors.Errorf("unmarshaling t.Request pointer: %w", err) } } @@ -107,7 +107,7 @@ func (t *transferMessage) UnmarshalCBOR(r io.Reader) error { } else { t.Response = new(transferResponse) if err := t.Response.UnmarshalCBOR(br); err != nil { - return err + return xerrors.Errorf("unmarshaling t.Response pointer: %w", err) } } diff --git a/message/transfer_request.go b/message/transfer_request.go index aa04ac86..765e6fc1 100644 --- a/message/transfer_request.go +++ b/message/transfer_request.go @@ -1,10 +1,17 @@ package message import ( + "bytes" "io" - "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer" + "github.com/filecoin-project/go-data-transfer/encoding" "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/codec/dagcbor" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" ) //go:generate cbor-gen-for transferRequest @@ -12,14 +19,13 @@ import ( // transferRequest is a struct that fulfills the DataTransferRequest interface. // its members are exported to be used by cbor-gen type transferRequest struct { - BCid string + BCid *cid.Cid Canc bool - PID []byte Part bool Pull bool - Stor []byte - Vouch []byte - VTyp string + Stor *cbg.Deferred + Vouch *cbg.Deferred + VTyp datatransfer.TypeIdentifier XferID uint64 } @@ -39,27 +45,38 @@ func (trq *transferRequest) IsPull() bool { } // VoucherType returns the Voucher ID -func (trq *transferRequest) VoucherType() string { +func (trq *transferRequest) VoucherType() datatransfer.TypeIdentifier { return trq.VTyp } // Voucher returns the Voucher bytes -func (trq *transferRequest) Voucher() []byte { - return trq.Vouch +func (trq *transferRequest) Voucher(decoder encoding.Decoder) (encoding.Encodable, error) { + if trq.Vouch == nil { + return nil, xerrors.New("No voucher present to read") + } + return decoder.DecodeFromCbor(trq.Vouch.Raw) } // BaseCid returns the Base CID func (trq *transferRequest) BaseCid() cid.Cid { - res, err := cid.Decode(trq.BCid) - if err != nil { + if trq.BCid == nil { return cid.Undef } - return res + return *trq.BCid } // Selector returns the message Selector bytes -func (trq *transferRequest) Selector() []byte { - return trq.Stor +func (trq *transferRequest) Selector() (ipld.Node, error) { + if trq.Stor == nil { + return nil, xerrors.New("No selector present to read") + } + builder := basicnode.Style.Any.NewBuilder() + reader := bytes.NewReader(trq.Stor.Raw) + err := dagcbor.Decoder(builder, reader) + if err != nil { + return nil, xerrors.Errorf("Error decoding selector: %w", err) + } + return builder.Build(), nil } // IsCancel returns true if this is a cancel request diff --git a/message/transfer_request_cbor_gen.go b/message/transfer_request_cbor_gen.go index 31742b96..8407ee0f 100644 --- a/message/transfer_request_cbor_gen.go +++ b/message/transfer_request_cbor_gen.go @@ -1,15 +1,16 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + package message import ( "fmt" "io" + datatransfer "github.com/filecoin-project/go-data-transfer" cbg "github.com/whyrusleeping/cbor-gen" xerrors "golang.org/x/xerrors" ) -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. - var _ = xerrors.Errorf func (t *transferRequest) MarshalCBOR(w io.Writer) error { @@ -17,16 +18,20 @@ func (t *transferRequest) MarshalCBOR(w io.Writer) error { _, err := w.Write(cbg.CborNull) return err } - if _, err := w.Write([]byte{137}); err != nil { + if _, err := w.Write([]byte{136}); err != nil { return err } - // t.BCid (string) (string) - if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.BCid)))); err != nil { - return err - } - if _, err := w.Write([]byte(t.BCid)); err != nil { - return err + // t.BCid (cid.Cid) (struct) + + if t.BCid == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCid(w, *t.BCid); err != nil { + return xerrors.Errorf("failed to write cid field t.BCid: %w", err) + } } // t.Canc (bool) (bool) @@ -34,14 +39,6 @@ func (t *transferRequest) MarshalCBOR(w io.Writer) error { return err } - // t.PID ([]uint8) (slice) - if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(t.PID)))); err != nil { - return err - } - if _, err := w.Write(t.PID); err != nil { - return err - } - // t.Part (bool) (bool) if err := cbg.WriteBool(w, t.Part); err != nil { return err @@ -52,23 +49,21 @@ func (t *transferRequest) MarshalCBOR(w io.Writer) error { return err } - // t.Stor ([]uint8) (slice) - if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(t.Stor)))); err != nil { - return err - } - if _, err := w.Write(t.Stor); err != nil { + // t.Stor (typegen.Deferred) (struct) + if err := t.Stor.MarshalCBOR(w); err != nil { return err } - // t.Vouch ([]uint8) (slice) - if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajByteString, uint64(len(t.Vouch)))); err != nil { + // t.Vouch (typegen.Deferred) (struct) + if err := t.Vouch.MarshalCBOR(w); err != nil { return err } - if _, err := w.Write(t.Vouch); err != nil { - return err + + // t.VTyp (datatransfer.TypeIdentifier) (string) + if len(t.VTyp) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.VTyp was too long") } - // t.VTyp (string) (string) if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.VTyp)))); err != nil { return err } @@ -77,9 +72,11 @@ func (t *transferRequest) MarshalCBOR(w io.Writer) error { } // t.XferID (uint64) (uint64) + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.XferID))); err != nil { return err } + return nil } @@ -94,19 +91,33 @@ func (t *transferRequest) UnmarshalCBOR(r io.Reader) error { return fmt.Errorf("cbor input should be of type array") } - if extra != 9 { + if extra != 8 { return fmt.Errorf("cbor input had wrong number of fields") } - // t.BCid (string) (string) + // t.BCid (cid.Cid) (struct) { - sval, err := cbg.ReadString(br) + + pb, err := br.PeekByte() if err != nil { return err } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.BCid: %w", err) + } + + t.BCid = &c + } - t.BCid = string(sval) } // t.Canc (bool) (bool) @@ -125,23 +136,6 @@ func (t *transferRequest) UnmarshalCBOR(r io.Reader) error { default: return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) } - // t.PID ([]uint8) (slice) - - maj, extra, err = cbg.CborReadHeader(br) - if err != nil { - return err - } - - if extra > cbg.ByteArrayMaxLen { - return fmt.Errorf("t.PID: byte array too large (%d)", extra) - } - if maj != cbg.MajByteString { - return fmt.Errorf("expected byte array") - } - t.PID = make([]byte, extra) - if _, err := io.ReadFull(br, t.PID); err != nil { - return err - } // t.Part (bool) (bool) maj, extra, err = cbg.CborReadHeader(br) @@ -176,41 +170,49 @@ func (t *transferRequest) UnmarshalCBOR(r io.Reader) error { default: return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) } - // t.Stor ([]uint8) (slice) + // t.Stor (typegen.Deferred) (struct) - maj, extra, err = cbg.CborReadHeader(br) - if err != nil { - return err - } + { - if extra > cbg.ByteArrayMaxLen { - return fmt.Errorf("t.Stor: byte array too large (%d)", extra) - } - if maj != cbg.MajByteString { - return fmt.Errorf("expected byte array") - } - t.Stor = make([]byte, extra) - if _, err := io.ReadFull(br, t.Stor); err != nil { - return err - } - // t.Vouch ([]uint8) (slice) + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + t.Stor = new(cbg.Deferred) + if err := t.Stor.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Stor pointer: %w", err) + } + } - maj, extra, err = cbg.CborReadHeader(br) - if err != nil { - return err } + // t.Vouch (typegen.Deferred) (struct) + + { + + pb, err := br.PeekByte() + if err != nil { + return err + } + if pb == cbg.CborNull[0] { + var nbuf [1]byte + if _, err := br.Read(nbuf[:]); err != nil { + return err + } + } else { + t.Vouch = new(cbg.Deferred) + if err := t.Vouch.UnmarshalCBOR(br); err != nil { + return xerrors.Errorf("unmarshaling t.Vouch pointer: %w", err) + } + } - if extra > cbg.ByteArrayMaxLen { - return fmt.Errorf("t.Vouch: byte array too large (%d)", extra) - } - if maj != cbg.MajByteString { - return fmt.Errorf("expected byte array") - } - t.Vouch = make([]byte, extra) - if _, err := io.ReadFull(br, t.Vouch); err != nil { - return err } - // t.VTyp (string) (string) + // t.VTyp (datatransfer.TypeIdentifier) (string) { sval, err := cbg.ReadString(br) @@ -218,17 +220,21 @@ func (t *transferRequest) UnmarshalCBOR(r io.Reader) error { return err } - t.VTyp = string(sval) + t.VTyp = datatransfer.TypeIdentifier(sval) } // t.XferID (uint64) (uint64) - maj, extra, err = cbg.CborReadHeader(br) - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") + { + + maj, extra, err = cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.XferID = uint64(extra) + } - t.XferID = uint64(extra) return nil } diff --git a/message/transfer_response_cbor_gen.go b/message/transfer_response_cbor_gen.go index 04f2db7e..3947af91 100644 --- a/message/transfer_response_cbor_gen.go +++ b/message/transfer_response_cbor_gen.go @@ -1,3 +1,5 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + package message import ( @@ -8,8 +10,6 @@ import ( xerrors "golang.org/x/xerrors" ) -// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. - var _ = xerrors.Errorf func (t *transferResponse) MarshalCBOR(w io.Writer) error { @@ -27,9 +27,11 @@ func (t *transferResponse) MarshalCBOR(w io.Writer) error { } // t.XferID (uint64) (uint64) + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, uint64(t.XferID))); err != nil { return err } + return nil } @@ -67,13 +69,17 @@ func (t *transferResponse) UnmarshalCBOR(r io.Reader) error { } // t.XferID (uint64) (uint64) - maj, extra, err = cbg.CborReadHeader(br) - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") + { + + maj, extra, err = cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.XferID = uint64(extra) + } - t.XferID = uint64(extra) return nil } diff --git a/network/libp2p_impl_test.go b/network/libp2p_impl_test.go index a44d6a15..9c78f88e 100644 --- a/network/libp2p_impl_test.go +++ b/network/libp2p_impl_test.go @@ -6,12 +6,14 @@ import ( "testing" "time" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + "github.com/ipld/go-ipld-prime/traversal/selector/builder" "github.com/libp2p/go-libp2p-core/peer" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/filecoin-project/go-data-transfer" + datatransfer "github.com/filecoin-project/go-data-transfer" "github.com/filecoin-project/go-data-transfer/message" "github.com/filecoin-project/go-data-transfer/network" "github.com/filecoin-project/go-data-transfer/testutil" @@ -81,12 +83,12 @@ func TestMessageSendAndReceive(t *testing.T) { t.Run("Send Request", func(t *testing.T) { baseCid := testutil.GenerateCids(1)[0] - selector := testutil.RandomBytes(100) + selector := builder.NewSelectorSpecBuilder(basicnode.Style.Any).Matcher().Node() isPull := false id := datatransfer.TransferID(rand.Int31()) - vType := "FakeVoucherType" - voucher := testutil.RandomBytes(100) - request := message.NewRequest(id, isPull, vType, voucher, baseCid, selector) + voucher := testutil.NewFakeDTType() + request, err := message.NewRequest(id, isPull, voucher.Type(), voucher, baseCid, selector) + require.NoError(t, err) require.NoError(t, dtnet1.SendMessage(ctx, host2.ID(), request)) select { @@ -106,9 +108,8 @@ func TestMessageSendAndReceive(t *testing.T) { assert.Equal(t, request.IsPull(), receivedRequest.IsPull()) assert.Equal(t, request.IsRequest(), receivedRequest.IsRequest()) assert.True(t, receivedRequest.BaseCid().Equals(request.BaseCid())) - assert.Equal(t, request.VoucherType(), receivedRequest.VoucherType()) - assert.Equal(t, request.Voucher(), receivedRequest.Voucher()) - assert.Equal(t, request.Selector(), receivedRequest.Selector()) + testutil.AssertEqualFakeDTVoucher(t, request, receivedRequest) + testutil.AssertEqualSelector(t, request, receivedRequest) }) t.Run("Send Response", func(t *testing.T) { diff --git a/registry/registry.go b/registry/registry.go new file mode 100644 index 00000000..4dc03707 --- /dev/null +++ b/registry/registry.go @@ -0,0 +1,68 @@ +package registry + +import ( + "sync" + + datatransfer "github.com/filecoin-project/go-data-transfer" + "github.com/filecoin-project/go-data-transfer/encoding" + "golang.org/x/xerrors" +) + +// Processor is an interface that processes a certain type of encodable objects +// in a registry. The actual specifics of the interface that must be satisfied are +// left to the user of the registry +type Processor interface{} + +type registryEntry struct { + decoder encoding.Decoder + processor Processor +} + +// Registry maintans a register of types of encodable objects and a corresponding +// processor for those objects +// The encodable types must have a method Type() that specifies and identifier +// so they correct decoding function and processor can be identified based +// on this unique identifier +type Registry struct { + registryLk sync.RWMutex + entries map[datatransfer.TypeIdentifier]registryEntry +} + +// NewRegistry initialzes a new registy +func NewRegistry() *Registry { + return &Registry{ + entries: make(map[datatransfer.TypeIdentifier]registryEntry), + } +} + +// Register registers the given processor for the given entry type +func (r *Registry) Register(entry datatransfer.Registerable, processor Processor) error { + identifier := entry.Type() + decoder, err := encoding.NewDecoder(entry) + if err != nil { + return xerrors.Errorf("registering entry type %s: %w", identifier, err) + } + r.registryLk.Lock() + defer r.registryLk.Unlock() + if _, ok := r.entries[identifier]; ok { + return xerrors.Errorf("identifier already registered: %s", identifier) + } + r.entries[identifier] = registryEntry{decoder, processor} + return nil +} + +// Decoder gets a decoder for the given identifier +func (r *Registry) Decoder(identifier datatransfer.TypeIdentifier) (encoding.Decoder, bool) { + r.registryLk.RLock() + entry, has := r.entries[identifier] + r.registryLk.RUnlock() + return entry.decoder, has +} + +// Processor gets the processing interface for the given identifer +func (r *Registry) Processor(identifier datatransfer.TypeIdentifier) (Processor, bool) { + r.registryLk.RLock() + entry, has := r.entries[identifier] + r.registryLk.RUnlock() + return entry.processor, has +} diff --git a/registry/registry_test.go b/registry/registry_test.go new file mode 100644 index 00000000..87fbf4a7 --- /dev/null +++ b/registry/registry_test.go @@ -0,0 +1,42 @@ +package registry_test + +import ( + "testing" + + "github.com/filecoin-project/go-data-transfer/registry" + "github.com/filecoin-project/go-data-transfer/testutil" + "github.com/stretchr/testify/require" +) + +func TestRegistry(t *testing.T) { + r := registry.NewRegistry() + t.Run("it registers", func(t *testing.T) { + err := r.Register(&testutil.FakeDTType{}, func() {}) + require.NoError(t, err) + }) + t.Run("it errors when registred again", func(t *testing.T) { + err := r.Register(&testutil.FakeDTType{}, func() {}) + require.EqualError(t, err, "identifier already registered: FakeDTType") + }) + t.Run("it errors when decoder setup fails", func(t *testing.T) { + err := r.Register(testutil.FakeDTType{}, func() {}) + require.EqualError(t, err, "registering entry type FakeDTType: type must be a pointer") + }) + t.Run("it reads decoders", func(t *testing.T) { + decoder, has := r.Decoder("FakeDTType") + require.True(t, has) + require.NotNil(t, decoder) + decoder, has = r.Decoder("OtherType") + require.False(t, has) + require.Nil(t, decoder) + }) + t.Run("it reads processors", func(t *testing.T) { + processor, has := r.Processor("FakeDTType") + require.True(t, has) + require.NotNil(t, processor) + processor, has = r.Processor("OtherType") + require.False(t, has) + require.Nil(t, processor) + }) + +} diff --git a/testutil/fakedttype.go b/testutil/fakedttype.go new file mode 100644 index 00000000..19071e58 --- /dev/null +++ b/testutil/fakedttype.go @@ -0,0 +1,52 @@ +package testutil + +import ( + "testing" + + "github.com/stretchr/testify/require" + + datatransfer "github.com/filecoin-project/go-data-transfer" + "github.com/filecoin-project/go-data-transfer/encoding" + "github.com/filecoin-project/go-data-transfer/message" +) + +//go:generate cbor-gen-for FakeDTType + +// FakeDTType simple fake type for using with registries +type FakeDTType struct { + Data string +} + +// Type satisfies registry.Entry +func (ft FakeDTType) Type() datatransfer.TypeIdentifier { + return "FakeDTType" +} + +// AssertFakeDTVoucher asserts that a data transfer requests contains the expected fake data transfer voucher type +func AssertFakeDTVoucher(t *testing.T, request message.DataTransferRequest, expected *FakeDTType) { + require.Equal(t, datatransfer.TypeIdentifier("FakeDTType"), request.VoucherType()) + fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) + require.NoError(t, err) + decoded, err := request.Voucher(fakeDTDecoder) + require.NoError(t, err) + require.Equal(t, expected, decoded) +} + +// AssertEqualFakeDTVoucher asserts that two requests have the same fake data transfer voucher +func AssertEqualFakeDTVoucher(t *testing.T, expectedRequest message.DataTransferRequest, request message.DataTransferRequest) { + require.Equal(t, expectedRequest.VoucherType(), request.VoucherType()) + fakeDTDecoder, err := encoding.NewDecoder(&FakeDTType{}) + require.NoError(t, err) + expectedDecoded, err := request.Voucher(fakeDTDecoder) + require.NoError(t, err) + decoded, err := request.Voucher(fakeDTDecoder) + require.NoError(t, err) + require.Equal(t, expectedDecoded, decoded) +} + +// NewFakeDTType returns a fake dt type with random data +func NewFakeDTType() *FakeDTType { + return &FakeDTType{Data: string(RandomBytes(100))} +} + +var _ datatransfer.Registerable = &FakeDTType{} diff --git a/testutil/fakedttype_cbor_gen.go b/testutil/fakedttype_cbor_gen.go new file mode 100644 index 00000000..f9a35e37 --- /dev/null +++ b/testutil/fakedttype_cbor_gen.go @@ -0,0 +1,64 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package testutil + +import ( + "fmt" + "io" + + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf + +func (t *FakeDTType) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{129}); err != nil { + return err + } + + // t.Data (string) (string) + if len(t.Data) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.Data was too long") + } + + if _, err := w.Write(cbg.CborEncodeMajorType(cbg.MajTextString, uint64(len(t.Data)))); err != nil { + return err + } + if _, err := w.Write([]byte(t.Data)); err != nil { + return err + } + return nil +} + +func (t *FakeDTType) UnmarshalCBOR(r io.Reader) error { + br := cbg.GetPeeker(r) + + maj, extra, err := cbg.CborReadHeader(br) + if err != nil { + return err + } + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 1 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.Data (string) (string) + + { + sval, err := cbg.ReadString(br) + if err != nil { + return err + } + + t.Data = string(sval) + } + return nil +} diff --git a/testutil/testutil.go b/testutil/testutil.go index da2af32d..da2d0f16 100644 --- a/testutil/testutil.go +++ b/testutil/testutil.go @@ -2,12 +2,15 @@ package testutil import ( "bytes" + "testing" + "github.com/filecoin-project/go-data-transfer/message" blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-cid" blocksutil "github.com/ipfs/go-ipfs-blocksutil" "github.com/jbenet/go-random" "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/require" ) var blockGenerator = blocksutil.NewBlockGenerator() @@ -81,3 +84,12 @@ func IndexOf(blks []blocks.Block, c cid.Cid) int { func ContainsBlock(blks []blocks.Block, block blocks.Block) bool { return IndexOf(blks, block.Cid()) != -1 } + +// AssertEqualSelector asserts two requests have the same valid selector +func AssertEqualSelector(t *testing.T, expectedRequest message.DataTransferRequest, request message.DataTransferRequest) { + expectedSelector, err := expectedRequest.Selector() + require.NoError(t, err) + selector, err := request.Selector() + require.NoError(t, err) + require.Equal(t, expectedSelector, selector) +} diff --git a/types.go b/types.go index 776776b9..1c88a9e1 100644 --- a/types.go +++ b/types.go @@ -2,26 +2,31 @@ package datatransfer import ( "context" - "reflect" "time" + "github.com/filecoin-project/go-data-transfer/encoding" "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" "github.com/libp2p/go-libp2p-core/peer" ) +// TypeIdentifier is a unique string identifier for a type of encodable object in a +// registry +type TypeIdentifier string + +// Registerable is a type of object in a registry. It must be encodable and must +// have a single method that uniquely identifies its type +type Registerable interface { + encoding.Encodable + // Type is a unique string identifier for this voucher type + Type() TypeIdentifier +} + // Voucher is used to validate // a data transfer request against the underlying storage or retrieval deal // that precipitated it. The only requirement is a voucher can read and write // from bytes, and has a string identifier type -type Voucher interface { - // ToBytes converts the Voucher to raw bytes - ToBytes() ([]byte, error) - // FromBytes reads a Voucher from raw bytes - FromBytes([]byte) error - // Type is a unique string identifier for this voucher type - Type() string -} +type Voucher Registerable // Status is the status of transfer for a given channel type Status int @@ -110,6 +115,7 @@ type ChannelState struct { received uint64 } +// EmptyChannelState is the zero value for channel state, meaning not present var EmptyChannelState = ChannelState{} // Sent returns the number of bytes sent @@ -171,7 +177,7 @@ type Manager interface { // RegisterVoucherType registers a validator for the given voucher type // will error if voucher type does not implement voucher // or if there is a voucher type registered with an identical identifier - RegisterVoucherType(voucherType reflect.Type, validator RequestValidator) error + RegisterVoucherType(voucherType Voucher, validator RequestValidator) error // open a data transfer that will send data to the recipient peer and // transfer parts of the piece that match the selector