Skip to content

Commit

Permalink
implement CidFromReader
Browse files Browse the repository at this point in the history
And reuse a CidFromBytes test for it, which includes both CIDv0 and
CIDv1 cases as inputs.

Fixes #126.
  • Loading branch information
mvdan committed Jul 2, 2021
1 parent 8e9280d commit 41f2377
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 0 deletions.
142 changes: 142 additions & 0 deletions cid.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,145 @@ func CidFromBytes(data []byte) (int, Cid, error) {

return l, Cid{string(data[0:l])}, nil
}

// toBufByteReader wraps r in a bufByteReader whose buffer starts out as dst.
//
// When r already provides ReadByte, that method is used for single-byte
// reads; otherwise the wrapper falls back to one-byte Read calls on r.
func toBufByteReader(r io.Reader, dst []byte) *bufByteReader {
	wrapped := &bufByteReader{dst: dst}
	if byteReader, ok := r.(io.ByteReader); ok {
		wrapped.direct = byteReader
	} else {
		wrapped.fallback = r
	}
	return wrapped
}

// bufByteReader is an io.ByteReader that first replays the bytes already
// held in dst and then pulls further bytes from an underlying reader,
// appending every newly read byte to dst.
type bufByteReader struct {
	// direct is the underlying reader when it implements io.ByteReader.
	direct io.ByteReader
	// fallback is the underlying reader, used via one-byte reads
	// whenever direct is nil.
	fallback io.Reader

	// consumed is incremented on every ReadByte call; while it is below
	// len(dst), bytes are served from dst instead of the underlying reader.
	consumed int
	// dst holds the initial bytes followed by every byte read so far.
	dst []byte
}

// ReadByte returns the next byte, serving the bytes initially present in dst
// before touching the underlying reader. Bytes obtained from the underlying
// reader are appended to dst; on a read error nothing is appended and the
// error is returned unchanged.
func (r *bufByteReader) ReadByte() (byte, error) {
	pos := r.consumed
	r.consumed++

	// Replay one of the bytes that were handed to us up front.
	if pos < len(r.dst) {
		return r.dst[pos], nil
	}

	var (
		next byte
		err  error
	)
	if r.direct != nil {
		// The underlying reader has its own ReadByte.
		next, err = r.direct.ReadByte()
	} else {
		// No ReadByte available; issue a one-byte Read instead.
		var one [1]byte
		_, err = io.ReadFull(r.fallback, one[:])
		next = one[0]
	}
	if err != nil {
		return 0, err
	}

	r.dst = append(r.dst, next)
	return next, nil
}

// CidFromReader reads a precise number of bytes for a CID from a given reader.
// It returns the number of bytes read, the CID, and any error encountered.
// The number of bytes read is accurate even if a non-nil error is returned.
//
// io.EOF is returned only when the reader had no data at all. If the input
// ends after at least one byte of the CID has been read, the error is
// io.ErrUnexpectedEOF, so callers decoding a sequence of CIDs can tell a
// clean end of stream apart from a truncated CID.
//
// It's recommended to supply a reader that buffers and implements io.ByteReader,
// as CidFromReader has to do many single-byte reads to decode varints.
// If the argument only implements io.Reader, single-byte Read calls are used instead.
func CidFromReader(r io.Reader) (int, Cid, error) {
	// Once the first bytes of a CID have been consumed, running out of input
	// is a truncation, never a clean end of stream.
	truncated := func(err error) error {
		if err == io.EOF {
			return io.ErrUnexpectedEOF
		}
		return err
	}

	// 64 bytes is enough for any CIDv0,
	// and it's enough for most CIDv1s in practice.
	// If the digest is too long, we'll allocate more.
	buf := make([]byte, 0, 64)

	// We read two bytes, to tell if this is a CIDv0 or a CIDv1.
	// io.ReadFull already reports io.EOF for zero bytes read and
	// io.ErrUnexpectedEOF for a single byte, so its error is returned as is.
	if n, err := io.ReadFull(r, buf[:2]); err != nil {
		return n, Undef, err
	}
	buf = buf[:2]

	// If we have a CIDv0, read the rest of the bytes and cast the buffer.
	if buf[0] == mh.SHA2_256 && buf[1] == 32 {
		if n, err := io.ReadFull(r, buf[2:34]); err != nil {
			return len(buf) + n, Undef, truncated(err)
		}

		buf = buf[:34]
		h, err := mh.Cast(buf)
		if err != nil {
			return len(buf), Undef, err
		}

		return len(buf), Cid{string(h)}, nil
	}

	// The varint package wants a io.ByteReader, so we must wrap our io.Reader.
	// Note that we already read two bytes, so bufByteReader uses those first.
	// After those two bytes, bufByteReader appends the read bytes to br.dst.
	br := toBufByteReader(r, buf[:2])
	vers, err := varint.ReadUvarint(br)
	if err != nil {
		return len(br.dst), Undef, truncated(err)
	}

	if vers != 1 {
		return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
	}

	// CID block encoding multicodec.
	if _, err := varint.ReadUvarint(br); err != nil {
		return len(br.dst), Undef, truncated(err)
	}

	// We could replace most of the code below with go-multihash's ReadMultihash.
	// Note that it would save code, but prevent reusing buffers.
	// Plus, we already have a ByteReader now.
	mhStart := len(br.dst)

	// Multihash hash function code.
	if _, err := varint.ReadUvarint(br); err != nil {
		return len(br.dst), Undef, truncated(err)
	}

	// Multihash digest length.
	mhl, err := varint.ReadUvarint(br)
	if err != nil {
		return len(br.dst), Undef, truncated(err)
	}

	// Pick up everything bufByteReader appended.
	// We're not reading single bytes beyond this point.
	buf = br.dst

	// Refuse to make large allocations to prevent OOMs due to bugs.
	const maxDigestAlloc = 32 << 20 // 32MiB
	if mhl > maxDigestAlloc {
		return len(buf), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)
	}

	// Read the digest directly after the header bytes, reusing buf's spare
	// capacity when it is large enough and growing it exactly once otherwise.
	digestStart := len(buf)
	digestEnd := digestStart + int(mhl)
	if digestEnd > cap(buf) {
		grown := make([]byte, digestStart, digestEnd)
		copy(grown, buf)
		buf = grown
	}
	if n, err := io.ReadFull(r, buf[digestStart:digestEnd]); err != nil {
		return digestStart + n, Undef, truncated(err)
	}
	buf = buf[:digestEnd]

	// This simply ensures the multihash is valid.
	if _, _, err := mh.MHFromBytes(buf[mhStart:]); err != nil {
		return len(buf), Undef, err
	}

	return len(buf), Cid{string(buf)}, nil
}
27 changes: 27 additions & 0 deletions cid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"math/rand"
"reflect"
"strings"
"testing"
"testing/iotest"

mbase "github.com/multiformats/go-multibase"
mh "github.com/multiformats/go-multihash"
Expand Down Expand Up @@ -692,6 +694,31 @@ func TestReadCidsFromBuffer(t *testing.T) {
if cur != len(buf) {
t.Fatal("had trailing bytes")
}

// The same, but now with CidFromReader.
// In multiple forms, to catch more io interface bugs.
for _, r := range []io.Reader{
// implements io.ByteReader
bytes.NewReader(buf),

// tiny reads, no io.ByteReader
iotest.OneByteReader(bytes.NewReader(buf)),
} {
cur = 0
for _, expc := range cids {
n, c, err := CidFromReader(r)
if err != nil {
t.Fatal(err)
}
if c != expc {
t.Fatal("cids mismatched")
}
cur += n
}
if cur != len(buf) {
t.Fatal("had trailing bytes")
}
}
}

func TestBadCidFromBytes(t *testing.T) {
Expand Down

0 comments on commit 41f2377

Please sign in to comment.