Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: wrap parsing errors into ErrInvalidCid #150

Merged
merged 2 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
94 changes: 58 additions & 36 deletions cid.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,32 @@ import (
// UnsupportedVersionString just holds an error message
const UnsupportedVersionString = "<unsupported cid version>"

// ErrInvalidCid is an error that indicates that a CID is invalid.
type ErrInvalidCid struct {
Err error
}

func (e ErrInvalidCid) Error() string {
return fmt.Sprintf("invalid cid: %s", e.Err)
}

func (e ErrInvalidCid) Unwrap() error {
return e.Err
}

func (e ErrInvalidCid) Is(err error) bool {
switch err.(type) {
case ErrInvalidCid, *ErrInvalidCid:
return true
default:
return false
}
}

var (
// ErrCidTooShort means that the cid passed to decode was not long
// enough to be a valid Cid
ErrCidTooShort = errors.New("cid too short")
ErrCidTooShort = ErrInvalidCid{errors.New("cid too short")}

// ErrInvalidEncoding means that selected encoding is not supported
// by this Cid version
Expand Down Expand Up @@ -90,10 +112,10 @@ func tryNewCidV0(mhash mh.Multihash) (Cid, error) {
// incorrectly detect it as CidV1 in the Version() method
dec, err := mh.Decode(mhash)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}
if dec.Code != mh.SHA2_256 || dec.Length != 32 {
return Undef, fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)
return Undef, ErrInvalidCid{fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)}
}
return Cid{string(mhash)}, nil
}
Expand Down Expand Up @@ -177,7 +199,7 @@ func Parse(v interface{}) (Cid, error) {
case Cid:
return v2, nil
default:
return Undef, fmt.Errorf("can't parse %+v as Cid", v2)
return Undef, ErrInvalidCid{fmt.Errorf("can't parse %+v as Cid", v2)}
}
}

Expand Down Expand Up @@ -210,15 +232,15 @@ func Decode(v string) (Cid, error) {
if len(v) == 46 && v[:2] == "Qm" {
hash, err := mh.FromB58String(v)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}

return tryNewCidV0(hash)
}

_, data, err := mbase.Decode(v)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}

return Cast(data)
Expand All @@ -240,7 +262,7 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
// check encoding is valid
_, err := mbase.NewEncoder(encoding)
if err != nil {
return -1, err
return -1, ErrInvalidCid{err}
}

return encoding, nil
Expand All @@ -260,11 +282,11 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
func Cast(data []byte) (Cid, error) {
nr, c, err := CidFromBytes(data)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}

if nr != len(data) {
return Undef, fmt.Errorf("trailing bytes in data buffer passed to cid Cast")
return Undef, ErrInvalidCid{fmt.Errorf("trailing bytes in data buffer passed to cid Cast")}
}

return c, nil
Expand Down Expand Up @@ -434,28 +456,28 @@ func (c Cid) Equals(o Cid) bool {
// UnmarshalJSON parses the JSON representation of a Cid.
func (c *Cid) UnmarshalJSON(b []byte) error {
if len(b) < 2 {
return fmt.Errorf("invalid cid json blob")
return ErrInvalidCid{fmt.Errorf("invalid cid json blob")}
}
obj := struct {
CidTarget string `json:"/"`
}{}
objptr := &obj
err := json.Unmarshal(b, &objptr)
if err != nil {
return err
return ErrInvalidCid{err}
}
if objptr == nil {
*c = Cid{}
return nil
}

if obj.CidTarget == "" {
return fmt.Errorf("cid was incorrectly formatted")
return ErrInvalidCid{fmt.Errorf("cid was incorrectly formatted")}
}

out, err := Decode(obj.CidTarget)
if err != nil {
return err
return ErrInvalidCid{err}
}

*c = out
Expand Down Expand Up @@ -542,12 +564,12 @@ func (p Prefix) Sum(data []byte) (Cid, error) {
if p.Version == 0 && (p.MhType != mh.SHA2_256 ||
(p.MhLength != 32 && p.MhLength != -1)) {

return Undef, fmt.Errorf("invalid v0 prefix")
return Undef, ErrInvalidCid{fmt.Errorf("invalid v0 prefix")}
}

hash, err := mh.Sum(data, p.MhType, length)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}

switch p.Version {
Expand All @@ -556,7 +578,7 @@ func (p Prefix) Sum(data []byte) (Cid, error) {
case 1:
return NewCidV1(p.Codec, hash), nil
default:
return Undef, fmt.Errorf("invalid cid version")
return Undef, ErrInvalidCid{fmt.Errorf("invalid cid version")}
}
}

Expand Down Expand Up @@ -586,22 +608,22 @@ func PrefixFromBytes(buf []byte) (Prefix, error) {
r := bytes.NewReader(buf)
vers, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}

codec, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}

mhtype, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}

mhlen, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}

return Prefix{
Expand All @@ -615,34 +637,34 @@ func PrefixFromBytes(buf []byte) (Prefix, error) {
func CidFromBytes(data []byte) (int, Cid, error) {
if len(data) > 2 && data[0] == mh.SHA2_256 && data[1] == 32 {
if len(data) < 34 {
return 0, Undef, fmt.Errorf("not enough bytes for cid v0")
return 0, Undef, ErrInvalidCid{fmt.Errorf("not enough bytes for cid v0")}
}

h, err := mh.Cast(data[:34])
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}

return 34, Cid{string(h)}, nil
}

vers, n, err := varint.FromUvarint(data)
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}

if vers != 1 {
return 0, Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
return 0, Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
}

_, cn, err := varint.FromUvarint(data[n:])
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}

mhnr, _, err := mh.MHFromBytes(data[n+cn:])
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}

l := n + cn + mhnr
Expand Down Expand Up @@ -705,32 +727,32 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// The varint package wants a io.ByteReader, so we must wrap our io.Reader.
vers, err := varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

// If we have a CIDv0, read the rest of the bytes and cast the buffer.
if vers == mh.SHA2_256 {
if n, err := io.ReadFull(r, br.dst[1:34]); err != nil {
return len(br.dst) + n, Undef, err
return len(br.dst) + n, Undef, ErrInvalidCid{err}
}

br.dst = br.dst[:34]
h, err := mh.Cast(br.dst)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

return len(br.dst), Cid{string(h)}, nil
}

if vers != 1 {
return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
}

// CID block encoding multicodec.
_, err = varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

// We could replace most of the code below with go-multihash's ReadMultihash.
Expand All @@ -741,19 +763,19 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// Multihash hash function code.
_, err = varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

// Multihash digest length.
mhl, err := varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

// Refuse to make large allocations to prevent OOMs due to bugs.
const maxDigestAlloc = 32 << 20 // 32MiB
if mhl > maxDigestAlloc {
return len(br.dst), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)
return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)}
}

// Fine to convert mhl to int, given maxDigestAlloc.
Expand All @@ -772,15 +794,15 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil {
// We can't use len(br.dst) here,
// as we've only read n bytes past prefixLength.
return prefixLength + n, Undef, err
return prefixLength + n, Undef, ErrInvalidCid{err}
}

// This simply ensures the multihash is valid.
// TODO: consider removing this bit, as it's probably redundant;
// for now, it helps ensure consistency with CidFromBytes.
_, _, err = mh.MHFromBytes(br.dst[mhStart:])
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}

return len(br.dst), Cid{string(br.dst)}, nil
Expand Down