diff --git a/cid.go b/cid.go
index f0d3a8f..6ff505c 100644
--- a/cid.go
+++ b/cid.go
@@ -680,3 +680,148 @@ func CidFromBytes(data []byte) (int, Cid, error) {
 
 	return l, Cid{string(data[0:l])}, nil
 }
+
+// toBufByteReader wraps r in a bufByteReader that appends every byte it reads to dst.
+// Readers that already implement io.ByteReader are used directly;
+// any other reader is read one byte at a time.
+func toBufByteReader(r io.Reader, dst []byte) *bufByteReader {
+	// If the reader already implements ByteReader, use it directly.
+	// Otherwise, use a fallback that does 1-byte Reads.
+	if br, ok := r.(io.ByteReader); ok {
+		return &bufByteReader{direct: br, dst: dst}
+	}
+	return &bufByteReader{fallback: r, dst: dst}
+}
+
+// bufByteReader is an io.ByteReader that records each byte it returns in dst.
+// toBufByteReader sets either direct or fallback, never both.
+type bufByteReader struct {
+	direct   io.ByteReader
+	fallback io.Reader
+
+	dst []byte
+}
+
+// ReadByte reads a single byte from the wrapped reader and appends it to dst.
+// If the read fails, the error is returned and dst is left untouched.
+func (r *bufByteReader) ReadByte() (byte, error) {
+	// The underlying reader has ReadByte; use it.
+	if br := r.direct; br != nil {
+		b, err := br.ReadByte()
+		if err != nil {
+			return 0, err
+		}
+		r.dst = append(r.dst, b)
+		return b, nil
+	}
+
+	// Fall back to a one-byte Read.
+	// TODO: consider reading straight into dst,
+	// once we have benchmarks and if they prove that to be faster.
+	var p [1]byte
+	if _, err := io.ReadFull(r.fallback, p[:]); err != nil {
+		return 0, err
+	}
+	r.dst = append(r.dst, p[0])
+	return p[0], nil
+}
+
+// CidFromReader reads a precise number of bytes for a CID from a given reader.
+// It returns the number of bytes read, the CID, and any error encountered.
+// The number of bytes read is accurate even if a non-nil error is returned.
+//
+// It's recommended to supply a reader that buffers and implements io.ByteReader,
+// as CidFromReader has to do many single-byte reads to decode varints.
+// If the argument only implements io.Reader, single-byte Read calls are used instead.
+func CidFromReader(r io.Reader) (int, Cid, error) {
+	// 64 bytes is enough for any CIDv0,
+	// and it's enough for most CIDv1s in practice.
+	// If the digest is too long, we'll allocate more.
+	br := toBufByteReader(r, make([]byte, 0, 64))
+
+	// We read the first varint, to tell if this is a CIDv0 or a CIDv1.
+	// The varint package wants a io.ByteReader, so we must wrap our io.Reader.
+	vers, err := varint.ReadUvarint(br)
+	if err != nil {
+		return len(br.dst), Undef, err
+	}
+
+	// If we have a CIDv0, read the rest of the bytes and cast the buffer.
+	if vers == mh.SHA2_256 {
+		if n, err := io.ReadFull(r, br.dst[1:34]); err != nil {
+			return len(br.dst) + n, Undef, err
+		}
+
+		br.dst = br.dst[:34]
+		h, err := mh.Cast(br.dst)
+		if err != nil {
+			return len(br.dst), Undef, err
+		}
+
+		return len(br.dst), Cid{string(h)}, nil
+	}
+
+	if vers != 1 {
+		return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
+	}
+
+	// CID block encoding multicodec.
+	_, err = varint.ReadUvarint(br)
+	if err != nil {
+		return len(br.dst), Undef, err
+	}
+
+	// We could replace most of the code below with go-multihash's ReadMultihash.
+	// Note that it would save code, but prevent reusing buffers.
+	// Plus, we already have a ByteReader now.
+	mhStart := len(br.dst)
+
+	// Multihash hash function code.
+	_, err = varint.ReadUvarint(br)
+	if err != nil {
+		return len(br.dst), Undef, err
+	}
+
+	// Multihash digest length.
+	mhl, err := varint.ReadUvarint(br)
+	if err != nil {
+		return len(br.dst), Undef, err
+	}
+
+	// Refuse to make large allocations to prevent OOMs due to bugs.
+	const maxDigestAlloc = 32 << 20 // 32MiB
+	if mhl > maxDigestAlloc {
+		return len(br.dst), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)
+	}
+
+	// Fine to convert mhl to int, given maxDigestAlloc.
+	prefixLength := len(br.dst)
+	cidLength := prefixLength + int(mhl)
+	if cidLength > cap(br.dst) {
+		// If the multihash digest doesn't fit in our initial 64 bytes,
+		// efficiently extend the slice via append+make.
+		// Grow by the distance from the current length (not the capacity),
+		// so that len(br.dst) ends up being exactly cidLength.
+		br.dst = append(br.dst, make([]byte, cidLength-len(br.dst))...)
+	} else {
+		// The multihash digest fits inside our buffer,
+		// so just extend its length.
+		br.dst = br.dst[:cidLength]
+	}
+
+	if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil {
+		// We can't use len(br.dst) here,
+		// as we've only read n bytes past prefixLength.
+		return prefixLength + n, Undef, err
+	}
+
+	// This simply ensures the multihash is valid.
+	// TODO: consider removing this bit, as it's probably redundant;
+	// for now, it helps ensure consistency with CidFromBytes.
+	_, _, err = mh.MHFromBytes(br.dst[mhStart:])
+	if err != nil {
+		return len(br.dst), Undef, err
+	}
+
+	return len(br.dst), Cid{string(br.dst)}, nil
+}
diff --git a/cid_test.go b/cid_test.go
index 28fc964..161e6dc 100644
--- a/cid_test.go
+++ b/cid_test.go
@@ -4,10 +4,12 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"io"
 	"math/rand"
 	"reflect"
 	"strings"
 	"testing"
+	"testing/iotest"
 
 	mbase "github.com/multiformats/go-multibase"
 	mh "github.com/multiformats/go-multihash"
@@ -692,51 +694,98 @@ func TestReadCidsFromBuffer(t *testing.T) {
 	if cur != len(buf) {
 		t.Fatal("had trailing bytes")
 	}
-}
 
-func TestBadCidFromBytes(t *testing.T) {
-	l, c, err := CidFromBytes([]byte{mh.SHA2_256, 32, 0x00})
-	if err == nil {
-		t.Fatal("expected not-enough-bytes for V0 CidFromBytes")
-	}
-	if l != 0 {
-		t.Fatal("expected length=0 from bad CidFromBytes")
-	}
-	if c != Undef {
-		t.Fatal("expected Undef CID from bad CidFromBytes")
-	}
+	// The same, but now with CidFromReader.
+	// In multiple forms, to catch more io interface bugs.
+	for _, r := range []io.Reader{
+		// implements io.ByteReader
+		bytes.NewReader(buf),
 
-	c, err = Decode("bafkreie5qrjvaw64n4tjm6hbnm7fnqvcssfed4whsjqxzslbd3jwhsk3mm")
-	if err != nil {
-		t.Fatal(err)
-	}
-	byts := make([]byte, c.ByteLen())
-	copy(byts, c.Bytes())
-	byts[1] = 0x80 // bad codec varint
-	byts[2] = 0x00
-	l, c, err = CidFromBytes(byts)
-	if err == nil {
-		t.Fatal("expected not-enough-bytes for V1 CidFromBytes")
-	}
-	if l != 0 {
-		t.Fatal("expected length=0 from bad CidFromBytes")
-	}
-	if c != Undef {
-		t.Fatal("expected Undef CID from bad CidFromBytes")
+		// tiny reads, no io.ByteReader
+		iotest.OneByteReader(bytes.NewReader(buf)),
+	} {
+		cur = 0
+		for _, expc := range cids {
+			n, c, err := CidFromReader(r)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if c != expc {
+				t.Fatal("cids mismatched")
+			}
+			cur += n
+		}
+		if cur != len(buf) {
+			t.Fatal("had trailing bytes")
+		}
 	}
+}
 
-	copy(byts, c.Bytes())
-	byts[2] = 0x80 // bad multihash varint
-	byts[3] = 0x00
-	l, c, err = CidFromBytes(byts)
-	if err == nil {
-		t.Fatal("expected not-enough-bytes for V1 CidFromBytes")
-	}
-	if l != 0 {
-		t.Fatal("expected length=0 from bad CidFromBytes")
-	}
-	if c != Undef {
-		t.Fatal("expected Undef CID from bad CidFromBytes")
+func TestBadCidInput(t *testing.T) {
+	for _, name := range []string{
+		"FromBytes",
+		"FromReader",
+	} {
+		t.Run(name, func(t *testing.T) {
+			usingReader := name == "FromReader"
+
+			fromBytes := CidFromBytes
+			if usingReader {
+				fromBytes = func(data []byte) (int, Cid, error) {
+					return CidFromReader(bytes.NewReader(data))
+				}
+			}
+
+			l, c, err := fromBytes([]byte{mh.SHA2_256, 32, 0x00})
+			if err == nil {
+				t.Fatal("expected not-enough-bytes for V0 CID")
+			}
+			if !usingReader && l != 0 {
+				t.Fatal("expected length==0 from bad CID")
+			} else if usingReader && l == 0 {
+				t.Fatal("expected length!=0 from bad CID")
+			}
+			if c != Undef {
+				t.Fatal("expected Undef CID from bad CID")
+			}
+
+			c, err = Decode("bafkreie5qrjvaw64n4tjm6hbnm7fnqvcssfed4whsjqxzslbd3jwhsk3mm")
+			if err != nil {
+				t.Fatal(err)
+			}
+			byts := make([]byte, c.ByteLen())
+			copy(byts, c.Bytes())
+			byts[1] = 0x80 // bad codec varint
+			byts[2] = 0x00
+			l, c, err = fromBytes(byts)
+			if err == nil {
+				t.Fatal("expected not-enough-bytes for V1 CID")
+			}
+			if !usingReader && l != 0 {
+				t.Fatal("expected length==0 from bad CID")
+			} else if usingReader && l == 0 {
+				t.Fatal("expected length!=0 from bad CID")
+			}
+			if c != Undef {
+				t.Fatal("expected Undef CID from bad CID")
+			}
+
+			copy(byts, c.Bytes())
+			byts[2] = 0x80 // bad multihash varint
+			byts[3] = 0x00
+			l, c, err = fromBytes(byts)
+			if err == nil {
+				t.Fatal("expected not-enough-bytes for V1 CID")
+			}
+			if !usingReader && l != 0 {
+				t.Fatal("expected length==0 from bad CID")
+			} else if usingReader && l == 0 {
+				t.Fatal("expected length!=0 from bad CID")
+			}
+			if c != Undef {
+				t.Fatal("expected Undef CID from bad CID")
+			}
+		})
 	}
 }
 