From fb270a3c0b8f004e26ee9a9eca4d2e2938546ef3 Mon Sep 17 00:00:00 2001 From: Henrique Dias Date: Wed, 15 Mar 2023 16:10:11 +0100 Subject: [PATCH 1/2] feat: wrap parsing errors into ErrInvalidCid --- cid.go | 94 +++++++++++++++++++++++--------------- cid_test.go | 129 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 184 insertions(+), 39 deletions(-) diff --git a/cid.go b/cid.go index 651c94d..ae3d1fd 100644 --- a/cid.go +++ b/cid.go @@ -37,10 +37,32 @@ import ( // UnsupportedVersionString just holds an error message const UnsupportedVersionString = "" +// ErrInvalidCid is an error that indicates that a CID is invalid. +type ErrInvalidCid struct { + Err error +} + +func (e ErrInvalidCid) Error() string { + return fmt.Sprintf("invalid cid: %s", e.Err) +} + +func (e ErrInvalidCid) Unwrap() error { + return e.Err +} + +func (e ErrInvalidCid) Is(err error) bool { + switch err.(type) { + case ErrInvalidCid, *ErrInvalidCid: + return true + default: + return false + } +} + var ( // ErrCidTooShort means that the cid passed to decode was not long // enough to be a valid Cid - ErrCidTooShort = errors.New("cid too short") + ErrCidTooShort = ErrInvalidCid{errors.New("cid too short")} // ErrInvalidEncoding means that selected encoding is not supported // by this Cid version @@ -90,10 +112,10 @@ func tryNewCidV0(mhash mh.Multihash) (Cid, error) { // incorrectly detect it as CidV1 in the Version() method dec, err := mh.Decode(mhash) if err != nil { - return Undef, err + return Undef, ErrInvalidCid{err} } if dec.Code != mh.SHA2_256 || dec.Length != 32 { - return Undef, fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length) + return Undef, ErrInvalidCid{fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)} } return Cid{string(mhash)}, nil } @@ -177,7 +199,7 @@ func Parse(v interface{}) (Cid, error) { case Cid: return v2, nil default: - return Undef, fmt.Errorf("can't parse %+v as Cid", v2) + return Undef, ErrInvalidCid{fmt.Errorf("can't parse %+v as Cid", v2)} } } @@ -210,7 +232,7 @@ func Decode(v string) (Cid, error) { if len(v) == 46 && v[:2] == "Qm" { hash, err := mh.FromB58String(v) if err != nil { - return Undef, err + return Undef, ErrInvalidCid{err} } return tryNewCidV0(hash) @@ -218,7 +240,7 @@ func Decode(v string) (Cid, error) { _, data, err := mbase.Decode(v) if err != nil { - return Undef, err + return Undef, ErrInvalidCid{err} } return Cast(data) @@ -240,7 +262,7 @@ func ExtractEncoding(v string) (mbase.Encoding, error) { // check encoding is valid _, err := mbase.NewEncoder(encoding) if err != nil { - return -1, err + return -1, ErrInvalidCid{err} } return encoding, nil @@ -260,11 +282,11 @@ func ExtractEncoding(v string) (mbase.Encoding, error) { func Cast(data []byte) (Cid, error) { nr, c, err := CidFromBytes(data) if err != nil { - return Undef, err + return Undef, ErrInvalidCid{err} } if nr != len(data) { - return Undef, fmt.Errorf("trailing bytes in data buffer passed to cid Cast") + return Undef, ErrInvalidCid{fmt.Errorf("trailing bytes in data buffer passed to cid Cast")} } return c, nil @@ -434,7 +456,7 @@ func (c Cid) Equals(o Cid) bool { // UnmarshalJSON parses the JSON representation of a Cid. func (c *Cid) UnmarshalJSON(b []byte) error { if len(b) < 2 { - return fmt.Errorf("invalid cid json blob") + return ErrInvalidCid{fmt.Errorf("invalid cid json blob")} } obj := struct { CidTarget string `json:"/"` @@ -442,7 +464,7 @@ func (c *Cid) UnmarshalJSON(b []byte) error { objptr := &obj err := json.Unmarshal(b, &objptr) if err != nil { - return err + return ErrInvalidCid{err} } if objptr == nil { *c = Cid{} @@ -450,12 +472,12 @@ func (c *Cid) UnmarshalJSON(b []byte) error { } if obj.CidTarget == "" { - return fmt.Errorf("cid was incorrectly formatted") + return ErrInvalidCid{fmt.Errorf("cid was incorrectly formatted")} } out, err := Decode(obj.CidTarget) if err != nil { - return err + return ErrInvalidCid{err} } *c = out @@ -542,12 +564,12 @@ func (p Prefix) Sum(data []byte) (Cid, error) { if p.Version == 0 && (p.MhType != mh.SHA2_256 || (p.MhLength != 32 && p.MhLength != -1)) { - return Undef, fmt.Errorf("invalid v0 prefix") + return Undef, ErrInvalidCid{fmt.Errorf("invalid v0 prefix")} } hash, err := mh.Sum(data, p.MhType, length) if err != nil { - return Undef, err + return Undef, ErrInvalidCid{err} } switch p.Version { @@ -556,7 +578,7 @@ func (p Prefix) Sum(data []byte) (Cid, error) { case 1: return NewCidV1(p.Codec, hash), nil default: - return Undef, fmt.Errorf("invalid cid version") + return Undef, ErrInvalidCid{fmt.Errorf("invalid cid version")} } } @@ -586,22 +608,22 @@ func PrefixFromBytes(buf []byte) (Prefix, error) { r := bytes.NewReader(buf) vers, err := varint.ReadUvarint(r) if err != nil { - return Prefix{}, err + return Prefix{}, ErrInvalidCid{err} } codec, err := varint.ReadUvarint(r) if err != nil { - return Prefix{}, err + return Prefix{}, ErrInvalidCid{err} } mhtype, err := varint.ReadUvarint(r) if err != nil { - return Prefix{}, err + return Prefix{}, ErrInvalidCid{err} } mhlen, err := varint.ReadUvarint(r) if err != nil { - return Prefix{}, err + return Prefix{}, ErrInvalidCid{err} } return Prefix{ @@ -615,12 +637,12 @@ func PrefixFromBytes(buf []byte) (Prefix, error) { func CidFromBytes(data []byte) (int, Cid, error) { if len(data) > 2 && data[0] == mh.SHA2_256 && data[1] == 32 { if len(data) < 34 { - return 0, Undef, fmt.Errorf("not enough bytes for cid v0") + return 0, Undef, ErrInvalidCid{fmt.Errorf("not enough bytes for cid v0")} } h, err := mh.Cast(data[:34]) if err != nil { - return 0, Undef, err + return 0, Undef, ErrInvalidCid{err} } return 34, Cid{string(h)}, nil @@ -628,21 +650,21 @@ func CidFromBytes(data []byte) (int, Cid, error) { vers, n, err := varint.FromUvarint(data) if err != nil { - return 0, Undef, err + return 0, Undef, ErrInvalidCid{err} } if vers != 1 { - return 0, Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers) + return 0, Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)} } _, cn, err := varint.FromUvarint(data[n:]) if err != nil { - return 0, Undef, err + return 0, Undef, ErrInvalidCid{err} } mhnr, _, err := mh.MHFromBytes(data[n+cn:]) if err != nil { - return 0, Undef, err + return 0, Undef, ErrInvalidCid{err} } l := n + cn + mhnr @@ -705,32 +727,32 @@ func CidFromReader(r io.Reader) (int, Cid, error) { // The varint package wants a io.ByteReader, so we must wrap our io.Reader. vers, err := varint.ReadUvarint(br) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, ErrInvalidCid{err} } // If we have a CIDv0, read the rest of the bytes and cast the buffer. if vers == mh.SHA2_256 { if n, err := io.ReadFull(r, br.dst[1:34]); err != nil { - return len(br.dst) + n, Undef, err + return len(br.dst) + n, Undef, ErrInvalidCid{err} } br.dst = br.dst[:34] h, err := mh.Cast(br.dst) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, ErrInvalidCid{err} } return len(br.dst), Cid{string(h)}, nil } if vers != 1 { - return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers) + return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)} } // CID block encoding multicodec. _, err = varint.ReadUvarint(br) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, ErrInvalidCid{err} } // We could replace most of the code below with go-multihash's ReadMultihash. @@ -741,19 +763,19 @@ func CidFromReader(r io.Reader) (int, Cid, error) { // Multihash hash function code. _, err = varint.ReadUvarint(br) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, ErrInvalidCid{err} } // Multihash digest length. mhl, err := varint.ReadUvarint(br) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, ErrInvalidCid{err} } // Refuse to make large allocations to prevent OOMs due to bugs. const maxDigestAlloc = 32 << 20 // 32MiB if mhl > maxDigestAlloc { - return len(br.dst), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl) + return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)} } // Fine to convert mhl to int, given maxDigestAlloc. @@ -772,7 +794,7 @@ func CidFromReader(r io.Reader) (int, Cid, error) { if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil { // We can't use len(br.dst) here, // as we've only read n bytes past prefixLength. - return prefixLength + n, Undef, err + return prefixLength + n, Undef, ErrInvalidCid{err} } // This simply ensures the multihash is valid. @@ -780,7 +802,7 @@ func CidFromReader(r io.Reader) (int, Cid, error) { // for now, it helps ensure consistency with CidFromBytes. _, _, err = mh.MHFromBytes(br.dst[mhStart:]) if err != nil { - return len(br.dst), Undef, err + return len(br.dst), Undef, ErrInvalidCid{err} } return len(br.dst), Cid{string(br.dst)}, nil diff --git a/cid_test.go b/cid_test.go index 930e194..31989da 100644 --- a/cid_test.go +++ b/cid_test.go @@ -4,6 +4,7 @@ import ( "bytes" crand "crypto/rand" "encoding/json" + "errors" "fmt" "io" "math/rand" @@ -162,6 +163,9 @@ func TestBasesMarshaling(t *testing.T) { if err == nil { t.Fatal("expected too-short error from ExtractEncoding") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } if ee != -1 { t.Fatal("expected -1 from too-short ExtractEncoding") } @@ -227,6 +231,9 @@ func TestEmptyString(t *testing.T) { if err == nil { t.Fatal("shouldnt be able to parse an empty cid") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("error must be ErrInvalidCid") + } } func TestV0Handling(t *testing.T) { @@ -282,6 +289,9 @@ func TestV0ErrorCases(t *testing.T) { if err == nil { t.Fatal("should have failed to decode that ref") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("error must be ErrInvalidCid") + } } func TestNewPrefixV1(t *testing.T) { @@ -372,6 +382,9 @@ func TestInvalidV0Prefix(t *testing.T) { if err == nil { t.Fatalf("should error (index %d)", i) } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } } } @@ -381,6 +394,9 @@ func TestBadPrefix(t *testing.T) { if err == nil { t.Fatalf("expected error on v3 prefix Sum") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } } func TestPrefixRoundtrip(t *testing.T) { @@ -417,18 +433,30 @@ func TestBadPrefixFromBytes(t *testing.T) { if err == nil { t.Fatal("expected error for bad byte 0") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } _, err = PrefixFromBytes([]byte{0x01, 0x80}) if err == nil { t.Fatal("expected error for bad byte 1") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } _, err = PrefixFromBytes([]byte{0x01, 0x01, 0x80}) if err == nil { t.Fatal("expected error for bad byte 2") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } _, err = PrefixFromBytes([]byte{0x01, 0x01, 0x01, 0x80}) if err == nil { t.Fatal("expected error for bad byte 3") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } } func Test16BytesVarint(t *testing.T) { @@ -455,6 +483,9 @@ func TestParse(t *testing.T) { if !strings.Contains(err.Error(), "can't parse 123 as Cid") { t.Fatalf("expected int error, got %s", err.Error()) } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatalf("expected ErrInvalidCid, got %s", err.Error()) + } theHash := "QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zR1n" h, err := mh.FromB58String(theHash) @@ -572,17 +603,29 @@ func TestJsonRoundTrip(t *testing.T) { t.Fatal("cids not equal for Cid") } - if err = actual2.UnmarshalJSON([]byte("1")); err == nil { + err = actual2.UnmarshalJSON([]byte("1")) + if err == nil { t.Fatal("expected error for too-short JSON") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } - if err = actual2.UnmarshalJSON([]byte(`{"nope":"nope"}`)); err == nil { + err = actual2.UnmarshalJSON([]byte(`{"nope":"nope"}`)) + if err == nil { t.Fatal("expected error for bad CID JSON") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } - if err = actual2.UnmarshalJSON([]byte(`bad "" json!`)); err == nil { + err = actual2.UnmarshalJSON([]byte(`bad "" json!`)) + if err == nil { t.Fatal("expected error for bad JSON") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("expected error to be ErrInvalidCid") + } var actual3 Cid enc, err = actual3.MarshalJSON() @@ -749,6 +792,9 @@ func TestBadParse(t *testing.T) { if err == nil { t.Fatal("expected to fail to parse an invalid CIDv1 CID") } + if !errors.Is(err, ErrInvalidCid{}) { + t.Fatal("error must be ErrInvalidCid") + } } func TestLoggable(t *testing.T) { @@ -763,3 +809,80 @@ func TestLoggable(t *testing.T) { t.Fatalf("did not get expected loggable form (got %v)", actual) } } + +func TestErrInvalidCidIs(t *testing.T) { + for i, test := range []struct { + err error + target error + }{ + {&ErrInvalidCid{}, ErrInvalidCid{}}, + {ErrInvalidCid{}, &ErrInvalidCid{}}, + {ErrInvalidCid{}, ErrInvalidCid{}}, + {&ErrInvalidCid{}, &ErrInvalidCid{}}, + } { + if !errors.Is(test.err, test.target) { + t.Fatalf("expected error to be ErrInvalidCid, case %d", i) + } + } +} + +func TestErrInvalidCid(t *testing.T) { + run := func(err error) { + if err == nil { + t.Fatal("expected error") + } + + if !strings.HasPrefix(err.Error(), "invalid cid: ") { + t.Fatal(`expected error message to contain "invalid cid: "`) + } + + is := errors.Is(err, ErrInvalidCid{}) + if !is { + t.Fatal("expected error to be ErrInvalidCid") + } + + if !errors.Is(err, &ErrInvalidCid{}) { + t.Fatal("expected error to be &ErrInvalidCid") + } + } + + _, err := Decode("") + run(err) + + _, err = Decode("not-a-cid") + run(err) + + _, err = Decode("bafyInvalid") + run(err) + + _, err = Decode("QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zIII") + run(err) + + _, err = Cast([]byte("invalid")) + run(err) + + _, err = Parse("not-a-cid") + run(err) + + _, err = Parse("bafyInvalid") + run(err) + + _, err = Parse("QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zIII") + run(err) + + _, err = Parse(123) + run(err) + + _, _, err = CidFromBytes([]byte("invalid")) + run(err) + + _, err = Prefix{}.Sum([]byte("data")) + run(err) + + _, err = PrefixFromBytes([]byte{0x80}) + run(err) + + _, err = ExtractEncoding("invalid ") + run(err) + +} From c764ccc641f64baaebc680638f566fd9dbb009db Mon Sep 17 00:00:00 2001 From: Henrique Dias Date: Mon, 20 Mar 2023 08:23:12 +0100 Subject: [PATCH 2/2] chore: version 0.4.0 --- version.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.json b/version.json index 908483a..372b6ea 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v0.3.2" + "version": "v0.4.0" }