diff --git a/cid.go b/cid.go index ae3d1fd..f182424 100644 --- a/cid.go +++ b/cid.go @@ -717,6 +717,9 @@ func (r *bufByteReader) ReadByte() (byte, error) { // It's recommended to supply a reader that buffers and implements io.ByteReader, // as CidFromReader has to do many single-byte reads to decode varints. // If the argument only implements io.Reader, single-byte Read calls are used instead. +// +// If the Reader is found to yield zero bytes, an io.EOF error is returned directly, in all +// other error cases, an ErrInvalidCid, wrapping the original error, is returned. func CidFromReader(r io.Reader) (int, Cid, error) { // 64 bytes is enough for any CIDv0, // and it's enough for most CIDv1s in practice. @@ -727,6 +730,11 @@ func CidFromReader(r io.Reader) (int, Cid, error) { // The varint package wants a io.ByteReader, so we must wrap our io.Reader. vers, err := varint.ReadUvarint(br) if err != nil { + if err == io.EOF { + // First-byte read in ReadUvarint errors with io.EOF, so reader has no data. + // Subsequent reads with an EOF will return io.ErrUnexpectedEOF and be wrapped here. + return 0, Undef, err + } return len(br.dst), Undef, ErrInvalidCid{err} } diff --git a/cid_test.go b/cid_test.go index 31989da..36fba76 100644 --- a/cid_test.go +++ b/cid_test.go @@ -783,6 +783,30 @@ func TestBadCidInput(t *testing.T) { } } +func TestFromReaderNoData(t *testing.T) { + // Reading no data from io.Reader should return io.EOF, not ErrInvalidCid. + n, cid, err := CidFromReader(bytes.NewReader(nil)) + if err != io.EOF { + t.Fatal("Expected io.EOF error") + } + if cid != Undef { + t.Fatal("Expected Undef CID") + } + if n != 0 { + t.Fatal("Expected 0 data") + } + + // Read byte indicatiing more data to and check error is ErrInvalidCid. + _, _, err = CidFromReader(bytes.NewReader([]byte{0x80})) + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("Expected ErrInvalidCid error") + } + // Check for expected wrapped error. + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Fatal("Expected error", io.ErrUnexpectedEOF) + } +} + func TestBadParse(t *testing.T) { hash, err := mh.Sum([]byte("foobar"), mh.SHA3_256, -1) if err != nil { diff --git a/version.json b/version.json index 372b6ea..26a7d47 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v0.4.0" + "version": "v0.4.1" }