diff --git a/cid.go b/cid.go index f0d3a8f..fba9227 100644 --- a/cid.go +++ b/cid.go @@ -680,3 +680,145 @@ func CidFromBytes(data []byte) (int, Cid, error) { return l, Cid{string(data[0:l])}, nil } + +func toBufByteReader(r io.Reader, dst []byte) *bufByteReader { + // If the reader already implements ByteReader, use it directly. + // Otherwise, use a fallback that does 1-byte Reads. + if br, ok := r.(io.ByteReader); ok { + return &bufByteReader{direct: br, dst: dst} + } + return &bufByteReader{fallback: r, dst: dst} +} + +type bufByteReader struct { + direct io.ByteReader + fallback io.Reader + + consumed int + dst []byte +} + +func (r *bufByteReader) ReadByte() (byte, error) { + // We still have some of the initial bytes to use. + if r.consumed < len(r.dst) { + b := r.dst[r.consumed] + r.consumed++ + return b, nil + } + r.consumed++ + + // The underlying reader has ReadByte; use it. + if br := r.direct; br != nil { + b, err := br.ReadByte() + if err != nil { + return 0, err + } + r.dst = append(r.dst, b) + return b, nil + } + + // Fall back to a one-byte Read. + var p [1]byte + if _, err := io.ReadFull(r.fallback, p[:]); err != nil { + return 0, err + } + r.dst = append(r.dst, p[0]) + return p[0], nil +} + +// CidFromReader reads a precise number of bytes for a CID from a given reader. +// It returns the number of bytes read, the CID, and any error encountered. +// The number of bytes read is accurate even if a non-nil error is returned. +// +// It's recommended to supply a reader that buffers and implements io.ByteReader, +// as CidFromReader has to do many single-byte reads to decode varints. +// If the argument only implements io.Reader, single-byte Read calls are used instead. +func CidFromReader(r io.Reader) (int, Cid, error) { + // 64 bytes is enough for any CIDv0, + // and it's enough for most CIDv1s in practice. + // If the digest is too long, we'll allocate more. + buf := make([]byte, 0, 64) + + // We read two bytes, to tell if this is a CIDv0 or a CIDv1. + if n, err := io.ReadFull(r, buf[:2]); err != nil { + return n, Undef, err + } + buf = buf[:2] + + // If we have a CIDv0, read the rest of the bytes and cast the buffer. + if buf[0] == mh.SHA2_256 && buf[1] == 32 { + if n, err := io.ReadFull(r, buf[2:34]); err != nil { + return len(buf) + n, Undef, err + } + + buf = buf[:34] + h, err := mh.Cast(buf) + if err != nil { + return len(buf), Undef, err + } + + return len(buf), Cid{string(h)}, nil + } + + // The varint package wants a io.ByteReader, so we must wrap our io.Reader. + // Note that we already read two bytes, so bufByteReader uses those first. + // After those two bytes, bufByteReader appends the read bytes to br.dst. + br := toBufByteReader(r, buf[:2]) + vers, err := varint.ReadUvarint(br) + if err != nil { + return len(br.dst), Undef, err + } + + if vers != 1 { + return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers) + } + + // CID block encoding multicodec. + _, err = varint.ReadUvarint(br) + if err != nil { + return len(br.dst), Undef, err + } + + // We could replace most of the code below with go-multihash's ReadMultihash. + // Note that it would save code, but prevent reusing buffers. + // Plus, we already have a ByteReader now. + mhStart := len(br.dst) + + // Multihash hash function code. + _, err = varint.ReadUvarint(br) + if err != nil { + return len(br.dst), Undef, err + } + + // Multihash digest length. + mhl, err := varint.ReadUvarint(br) + if err != nil { + return len(br.dst), Undef, err + } + + // Update buf's length. + // We're not reading single bytes beyond this point. + buf = br.dst + br = nil + + // Multihash digest; might be too long, so allocate. + // Refuse to make large allocations to prevent OOMs due to bugs. + // TODO: reuse buf if it has enough space + const maxDigestAlloc = 32 << 20 // 32MiB + if mhl > maxDigestAlloc { + return len(buf), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl) + } + digest := make([]byte, int(mhl)) + if n, err := io.ReadFull(r, digest); err != nil { + return len(buf) + n, Undef, err + } + buf = append(buf, digest...) + + // This simply ensures the multihash is valid. + _, _, err = mh.MHFromBytes(buf[mhStart:]) + if err != nil { + return len(buf), Undef, err + } + + return len(buf), Cid{string(buf)}, nil +} diff --git a/cid_test.go b/cid_test.go index 28fc964..7bcb75f 100644 --- a/cid_test.go +++ b/cid_test.go @@ -4,10 +4,12 @@ import ( "bytes" "encoding/json" "fmt" + "io" "math/rand" "reflect" "strings" "testing" + "testing/iotest" mbase "github.com/multiformats/go-multibase" mh "github.com/multiformats/go-multihash" @@ -692,6 +694,31 @@ func TestReadCidsFromBuffer(t *testing.T) { if cur != len(buf) { t.Fatal("had trailing bytes") } + + // The same, but now with CidFromReader. + // In multiple forms, to catch more io interface bugs. + for _, r := range []io.Reader{ + // implements io.ByteReader + bytes.NewReader(buf), + + // tiny reads, no io.ByteReader + iotest.OneByteReader(bytes.NewReader(buf)), + } { + cur = 0 + for _, expc := range cids { + n, c, err := CidFromReader(r) + if err != nil { + t.Fatal(err) + } + if c != expc { + t.Fatal("cids mismatched") + } + cur += n + } + if cur != len(buf) { + t.Fatal("had trailing bytes") + } + } } func TestBadCidFromBytes(t *testing.T) {