diff --git a/allocate_go119_test.go b/allocate_go119_test.go new file mode 100644 index 0000000..984cb67 --- /dev/null +++ b/allocate_go119_test.go @@ -0,0 +1,11 @@ +//go:build !go1.20 + +package multihash + +import "testing" + +func mustNotAllocateMore(_ *testing.T, _ float64, f func()) { + // the compiler isn't able to detect our outlined stack allocation on before + // 1.20 so let's not test for it. We don't mind if outdated versions are slightly slower. + f() +} diff --git a/allocate_go120_test.go b/allocate_go120_test.go new file mode 100644 index 0000000..78f1286 --- /dev/null +++ b/allocate_go120_test.go @@ -0,0 +1,12 @@ +//go:build go1.20 + +package multihash + +import "testing" + +func mustNotAllocateMore(t *testing.T, n float64, f func()) { + t.Helper() + if b := testing.AllocsPerRun(10, f); b > n { + t.Errorf("it allocated %f max %f !", b, n) + } +} diff --git a/multihash.go b/multihash.go index 58e631d..1ef8d92 100644 --- a/multihash.go +++ b/multihash.go @@ -27,7 +27,7 @@ var ( // ErrInconsistentLen is returned when a decoded multihash has an inconsistent length type ErrInconsistentLen struct { - dm *DecodedMultihash + dm DecodedMultihash lengthFound int } @@ -222,12 +222,26 @@ func Cast(buf []byte) (Multihash, error) { // Decode parses multihash bytes into a DecodedMultihash. func Decode(buf []byte) (*DecodedMultihash, error) { - rlen, code, hdig, err := readMultihashFromBuf(buf) + // outline decode allowing the &dm expression to be inlined into the caller. + // This moves the heap allocation into the caller and if the caller doesn't + // leak dm the compiler will use a stack allocation instead. + // If you do not outline this &dm always heap allocate since the pointer is + // returned which cause a heap allocation because Decode's stack frame is + // about to disapear. + dm, err := decode(buf) if err != nil { return nil, err } + return &dm, nil +} + +func decode(buf []byte) (dm DecodedMultihash, err error) { + rlen, code, hdig, err := readMultihashFromBuf(buf) + if err != nil { + return DecodedMultihash{}, err + } - dm := &DecodedMultihash{ + dm = DecodedMultihash{ Code: code, Name: Codes[code], Length: len(hdig), @@ -235,7 +249,7 @@ func Decode(buf []byte) (*DecodedMultihash, error) { } if len(buf) != rlen { - return nil, ErrInconsistentLen{dm, rlen} + return dm, ErrInconsistentLen{dm, rlen} } return dm, nil diff --git a/multihash_test.go b/multihash_test.go index 6230e29..4080f8e 100644 --- a/multihash_test.go +++ b/multihash_test.go @@ -151,27 +151,29 @@ func TestDecode(t *testing.T) { nb := append(pre[:n], ob...) - dec, err := Decode(nb) - if err != nil { - t.Error(err) - continue - } + mustNotAllocateMore(t, 0, func() { + dec, err := Decode(nb) + if err != nil { + t.Error(err) + return + } - if dec.Code != tc.code { - t.Error("decoded code mismatch: ", dec.Code, tc.code) - } + if dec.Code != tc.code { + t.Error("decoded code mismatch: ", dec.Code, tc.code) + } - if dec.Name != tc.name { - t.Error("decoded name mismatch: ", dec.Name, tc.name) - } + if dec.Name != tc.name { + t.Error("decoded name mismatch: ", dec.Name, tc.name) + } - if dec.Length != len(ob) { - t.Error("decoded length mismatch: ", dec.Length, len(ob)) - } + if dec.Length != len(ob) { + t.Error("decoded length mismatch: ", dec.Length, len(ob)) + } - if !bytes.Equal(dec.Digest, ob) { - t.Error("decoded byte mismatch: ", dec.Digest, ob) - } + if !bytes.Equal(dec.Digest, ob) { + t.Error("decoded byte mismatch: ", dec.Digest, ob) + } + }) } } @@ -242,15 +244,20 @@ func TestCast(t *testing.T) { nb := append(pre[:n], ob...) - if _, err := Cast(nb); err != nil { - t.Error(err) - continue - } + mustNotAllocateMore(t, 0, func() { + if _, err := Cast(nb); err != nil { + t.Error(err) + return + } + }) - if _, err = Cast(ob); err == nil { - t.Error("cast failed to detect non-multihash") - continue - } + // 1 for the error object. + mustNotAllocateMore(t, 1, func() { + if _, err = Cast(ob); err == nil { + t.Error("cast failed to detect non-multihash") + return + } + }) } } @@ -343,8 +350,29 @@ func BenchmarkDecode(b *testing.B) { pre[1] = byte(uint8(len(ob))) nb := append(pre, ob...) + b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { Decode(nb) } } + +func BenchmarkCast(b *testing.B) { + tc := testCases[0] + ob, err := hex.DecodeString(tc.hex) + if err != nil { + b.Error(err) + return + } + + pre := make([]byte, 2) + pre[0] = byte(uint8(tc.code)) + pre[1] = byte(uint8(len(ob))) + nb := append(pre, ob...) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Cast(nb) + } +}