diff --git a/p2p/stream/protocols/sync/chain.go b/p2p/stream/protocols/sync/chain.go index 14e3be5f8e..a095fffc1f 100644 --- a/p2p/stream/protocols/sync/chain.go +++ b/p2p/stream/protocols/sync/chain.go @@ -168,8 +168,7 @@ func (ch *chainHelperImpl) getNodeData(hs []common.Hash) ([][]byte, error) { // getReceipts assembles the response to a receipt query. func (ch *chainHelperImpl) getReceipts(hs []common.Hash) ([]types.Receipts, error) { - var receipts []types.Receipts - + receipts := make([]types.Receipts, len(hs)) for i, hash := range hs { // Retrieve the requested block's receipts results := ch.chain.GetReceiptsByHash(hash) @@ -177,8 +176,9 @@ func (ch *chainHelperImpl) getReceipts(hs []common.Hash) ([]types.Receipts, erro if header := ch.chain.GetHeaderByHash(hash); header == nil || header.ReceiptHash() != types.EmptyRootHash { continue } + return nil, errors.New("invalid hashes to get receipts") } - receipts[i] = append(receipts[i], results...) + receipts[i] = results } return receipts, nil } diff --git a/p2p/stream/protocols/sync/chain_test.go b/p2p/stream/protocols/sync/chain_test.go index 8883d7cb5d..3f3f68b889 100644 --- a/p2p/stream/protocols/sync/chain_test.go +++ b/p2p/stream/protocols/sync/chain_test.go @@ -53,13 +53,28 @@ func (tch *testChainHelper) getNodeData(hs []common.Hash) ([][]byte, error) { func (tch *testChainHelper) getReceipts(hs []common.Hash) ([]types.Receipts, error) { testReceipts := makeTestReceipts(len(hs), 3) - receipts := make([]types.Receipts, len(hs)*3) + receipts := make([]types.Receipts, len(hs)) for i, _ := range hs { receipts[i] = testReceipts } return receipts, nil } +func checkGetReceiptsResult(b []byte, hs []common.Hash) error { + var msg = &syncpb.Message{} + if err := protobuf.Unmarshal(b, msg); err != nil { + return err + } + bhResp, err := msg.GetReceiptsResponse() + if err != nil { + return err + } + if len(hs) != len(bhResp.Receipts) { + return errors.New("unexpected size") + } + return nil +} + func numberToHash(bn uint64) common.Hash { var h common.Hash binary.LittleEndian.PutUint64(h[:], bn) diff --git a/p2p/stream/protocols/sync/message/parse.go b/p2p/stream/protocols/sync/message/parse.go index 22e6102204..b20b2f1f1b 100644 --- a/p2p/stream/protocols/sync/message/parse.go +++ b/p2p/stream/protocols/sync/message/parse.go @@ -79,3 +79,19 @@ func (msg *Message) GetBlocksByHashesResponse() (*GetBlocksByHashesResponse, err } return gbResp, nil } + +// GetReceiptsResponse parse the message to GetReceiptsResponse +func (msg *Message) GetReceiptsResponse() (*GetReceiptsResponse, error) { + resp := msg.GetResp() + if resp == nil { + return nil, errors.New("not response message") + } + if errResp := resp.GetErrorResponse(); errResp != nil { + return nil, &ResponseError{errResp.Error} + } + grResp := resp.GetGetReceiptsResponse() + if grResp == nil { + return nil, errors.New("not GetGetReceiptsResponse") + } + return grResp, nil +} diff --git a/p2p/stream/protocols/sync/stream_test.go b/p2p/stream/protocols/sync/stream_test.go index a9aae57faa..9f134ee133 100644 --- a/p2p/stream/protocols/sync/stream_test.go +++ b/p2p/stream/protocols/sync/stream_test.go @@ -40,6 +40,16 @@ var ( } testGetBlocksByHashesRequest = syncpb.MakeGetBlocksByHashesRequest(testGetBlockByHashes) testGetBlocksByHashesRequestMsg = syncpb.MakeMessageFromRequest(testGetBlocksByHashesRequest) + + testGetReceipts = []common.Hash{ + numberToHash(1), + numberToHash(2), + numberToHash(3), + numberToHash(4), + numberToHash(5), + } + testGetReceiptsRequest = syncpb.MakeGetReceiptsRequest(testGetReceipts) + testGetReceiptsRequestMsg = syncpb.MakeMessageFromRequest(testGetReceiptsRequest) ) func TestSyncStream_HandleGetBlocksByRequest(t *testing.T) { @@ -126,6 +136,27 @@ func TestSyncStream_HandleGetBlocksByHashes(t *testing.T) { } } +func TestSyncStream_HandleGetReceipts(t *testing.T) { + st, remoteSt := makeTestSyncStream() + + go st.run() + defer close(st.closeC) + + req := testGetReceiptsRequestMsg + b, _ := protobuf.Marshal(req) + err := remoteSt.WriteBytes(b) + if err != nil { + t.Fatal(err) + } + + time.Sleep(200 * time.Millisecond) + receivedBytes, _ := remoteSt.ReadBytes() + + if err := checkGetReceiptsResult(receivedBytes, testGetBlockByHashes); err != nil { + t.Fatal(err) + } +} + func makeTestSyncStream() (*syncStream, *testRemoteBaseStream) { localRaw, remoteRaw := makePairP2PStreams() remote := newTestRemoteBaseStream(remoteRaw)