From 65e3baa1cbeb24cad533afb47266246f29f15283 Mon Sep 17 00:00:00 2001 From: Henrique Dias Date: Wed, 15 Mar 2023 16:10:11 +0100 Subject: [PATCH] feat: wrap parsing errors into ErrInvalidCid --- cid.go | 72 ++++++++++++++++++++++++++++++++++------------------- cid_test.go | 22 ++++++++++++++++ 2 files changed, 69 insertions(+), 25 deletions(-) diff --git a/cid.go b/cid.go index 651c94d..9fc552c 100644 --- a/cid.go +++ b/cid.go @@ -37,10 +37,32 @@ import ( // UnsupportedVersionString just holds an error message const UnsupportedVersionString = "" +// ErrInvalidCid is an error that indicates that a CID is invalid. +type ErrInvalidCid struct { + Err error +} + +func (e *ErrInvalidCid) Error() string { + return fmt.Sprintf("invalid cid: %s", e.Err) +} + +func (e *ErrInvalidCid) Unwrap() error { + return e.Err +} + +func (e *ErrInvalidCid) Is(err error) bool { + switch err.(type) { + case *ErrInvalidCid: + return true + default: + return false + } +} + var ( // ErrCidTooShort means that the cid passed to decode was not long // enough to be a valid Cid - ErrCidTooShort = errors.New("cid too short") + ErrCidTooShort = &ErrInvalidCid{errors.New("cid too short")} // ErrInvalidEncoding means that selected encoding is not supported // by this Cid version @@ -90,10 +112,10 @@ func tryNewCidV0(mhash mh.Multihash) (Cid, error) { // incorrectly detect it as CidV1 in the Version() method dec, err := mh.Decode(mhash) if err != nil { - return Undef, err + return Undef, &ErrInvalidCid{err} } if dec.Code != mh.SHA2_256 || dec.Length != 32 { - return Undef, fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length) + return Undef, &ErrInvalidCid{fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)} } return Cid{string(mhash)}, nil } @@ -177,7 +199,7 @@ func Parse(v interface{}) (Cid, error) { case Cid: return v2, nil default: - return Undef, fmt.Errorf("can't parse %+v as Cid", v2) + return Undef, &ErrInvalidCid{fmt.Errorf("can't parse %+v as Cid", v2)} } } @@ -210,7 +232,7 @@ func Decode(v string) (Cid, error) { if len(v) == 46 && v[:2] == "Qm" { hash, err := mh.FromB58String(v) if err != nil { - return Undef, err + return Undef, &ErrInvalidCid{err} } return tryNewCidV0(hash) @@ -218,7 +240,7 @@ func Decode(v string) (Cid, error) { _, data, err := mbase.Decode(v) if err != nil { - return Undef, err + return Undef, &ErrInvalidCid{err} } return Cast(data) @@ -240,7 +262,7 @@ func ExtractEncoding(v string) (mbase.Encoding, error) { // check encoding is valid _, err := mbase.NewEncoder(encoding) if err != nil { - return -1, err + return -1, &ErrInvalidCid{err} } return encoding, nil @@ -260,11 +282,11 @@ func ExtractEncoding(v string) (mbase.Encoding, error) { func Cast(data []byte) (Cid, error) { nr, c, err := CidFromBytes(data) if err != nil { - return Undef, err + return Undef, &ErrInvalidCid{err} } if nr != len(data) { - return Undef, fmt.Errorf("trailing bytes in data buffer passed to cid Cast") + return Undef, &ErrInvalidCid{fmt.Errorf("trailing bytes in data buffer passed to cid Cast")} } return c, nil @@ -615,12 +637,12 @@ func PrefixFromBytes(buf []byte) (Prefix, error) { func CidFromBytes(data []byte) (int, Cid, error) { if len(data) > 2 && data[0] == mh.SHA2_256 && data[1] == 32 { if len(data) < 34 { - return 0, Undef, fmt.Errorf("not enough bytes for cid v0") + return 0, Undef, &ErrInvalidCid{fmt.Errorf("not enough bytes for cid v0")} } h, err := mh.Cast(data[:34]) if err != nil { - return 0, Undef, err + return 0, Undef, &ErrInvalidCid{err} } return 34, Cid{string(h)}, nil @@ -628,21 +650,21 @@ func CidFromBytes(data []byte) (int, Cid, error) { vers, n, err := varint.FromUvarint(data) if err != nil { - return 0, Undef, err + return 0, Undef, &ErrInvalidCid{err} } if vers != 1 { - return 0, Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers) + return 0, Undef, &ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)} } _, cn, err := varint.FromUvarint(data[n:]) if err != nil { - return 0, Undef, err + return 0, Undef, &ErrInvalidCid{err} } mhnr, _, err := mh.MHFromBytes(data[n+cn:]) if err != nil { - return 0, Undef, err + return 0, Undef, &ErrInvalidCid{err} } l := n + cn + mhnr @@ -705,32 +727,32 @@ func CidFromReader(r io.Reader) (int, Cid, error) { // The varint package wants a io.ByteReader, so we must wrap our io.Reader. vers, err := varint.ReadUvarint(br) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, &ErrInvalidCid{err} } // If we have a CIDv0, read the rest of the bytes and cast the buffer. if vers == mh.SHA2_256 { if n, err := io.ReadFull(r, br.dst[1:34]); err != nil { - return len(br.dst) + n, Undef, err + return len(br.dst) + n, Undef, &ErrInvalidCid{err} } br.dst = br.dst[:34] h, err := mh.Cast(br.dst) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, &ErrInvalidCid{err} } return len(br.dst), Cid{string(h)}, nil } if vers != 1 { - return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers) + return len(br.dst), Undef, &ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)} } // CID block encoding multicodec. _, err = varint.ReadUvarint(br) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, &ErrInvalidCid{err} } // We could replace most of the code below with go-multihash's ReadMultihash. @@ -741,19 +763,19 @@ func CidFromReader(r io.Reader) (int, Cid, error) { // Multihash hash function code. _, err = varint.ReadUvarint(br) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, &ErrInvalidCid{err} } // Multihash digest length. mhl, err := varint.ReadUvarint(br) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, &ErrInvalidCid{err} } // Refuse to make large allocations to prevent OOMs due to bugs. const maxDigestAlloc = 32 << 20 // 32MiB if mhl > maxDigestAlloc { - return len(br.dst), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl) + return len(br.dst), Undef, &ErrInvalidCid{fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)} } // Fine to convert mhl to int, given maxDigestAlloc. @@ -772,7 +794,7 @@ func CidFromReader(r io.Reader) (int, Cid, error) { if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil { // We can't use len(br.dst) here, // as we've only read n bytes past prefixLength. - return prefixLength + n, Undef, err + return prefixLength + n, Undef, &ErrInvalidCid{err} } // This simply ensures the multihash is valid. @@ -780,7 +802,7 @@ func CidFromReader(r io.Reader) (int, Cid, error) { // for now, it helps ensure consistency with CidFromBytes. _, _, err = mh.MHFromBytes(br.dst[mhStart:]) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, &ErrInvalidCid{err} } return len(br.dst), Cid{string(br.dst)}, nil diff --git a/cid_test.go b/cid_test.go index 930e194..4fde0b3 100644 --- a/cid_test.go +++ b/cid_test.go @@ -4,6 +4,7 @@ import ( "bytes" crand "crypto/rand" "encoding/json" + "errors" "fmt" "io" "math/rand" @@ -227,6 +228,9 @@ func TestEmptyString(t *testing.T) { if err == nil { t.Fatal("shouldnt be able to parse an empty cid") } + if !errors.Is(err, &ErrInvalidCid{}) { + t.Fatal("error must be ErrInvalidCid") + } } func TestV0Handling(t *testing.T) { @@ -282,6 +286,9 @@ func TestV0ErrorCases(t *testing.T) { if err == nil { t.Fatal("should have failed to decode that ref") } + if !errors.Is(err, &ErrInvalidCid{}) { + t.Fatal("error must be ErrInvalidCid") + } } func TestNewPrefixV1(t *testing.T) { @@ -749,6 +756,9 @@ func TestBadParse(t *testing.T) { if err == nil { t.Fatal("expected to fail to parse an invalid CIDv1 CID") } + if !errors.Is(err, &ErrInvalidCid{}) { + t.Fatal("error must be ErrInvalidCid") + } } func TestLoggable(t *testing.T) { @@ -763,3 +773,15 @@ func TestLoggable(t *testing.T) { t.Fatalf("did not get expected loggable form (got %v)", actual) } } + +func TestErrInvalidCid(t *testing.T) { + _, err := Decode("not-a-cid") + if err == nil { + t.Fatal("expected error") + } + + is := errors.Is(err, &ErrInvalidCid{}) + if !is { + t.Fatal("expected error to be ErrInvalidCid") + } +}