From d83b0f6a1d9794300a24e283f92e2b62fe404975 Mon Sep 17 00:00:00 2001 From: Hippolyte Barraud Date: Tue, 27 Jun 2023 14:35:44 -0400 Subject: [PATCH] blob: pass through reader/writer to `WriteTo`/`ReadFrom` if available MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The blob `Reader` and `Writer` implement `WriterTo` and `ReaderFrom`, respectively, which is meant to avoid intermediary allocations and copies (passing readers and readers through all the way down). Unfortunately, the current implementation falls short and implements those interfaces by making a local copy ala io.Copy. Instead, reproduce the strategy in newer versions of `io.Copy`: reflect on the reader/writer provided by the user, and if they implement these interface, too, then call into that. Local benchmarks show significant performance improvements: ``` │ /tmp/old │ /tmp/new │ │ sec/op │ sec/op vs base │ Memblob/BenchmarkRead-10 3.449µ ± 15% 1.919µ ± 2% -44.37% (p=0.000 n=10) ``` --- blob/blob.go | 12 ++++++++++++ blob/blob_reader_test.go | 25 ++++++++++++++++++++++--- blob/drivertest/drivertest.go | 13 +++++++++++-- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/blob/blob.go b/blob/blob.go index f530e48443..729c34fe4f 100644 --- a/blob/blob.go +++ b/blob/blob.go @@ -233,6 +233,12 @@ func (r *Reader) As(i interface{}) bool { // // It implements the io.WriterTo interface. func (r *Reader) WriteTo(w io.Writer) (int64, error) { + // If the writer has a ReaderFrom method, use it to do the copy. + // Avoids an allocation and a copy. + if rt, ok := w.(io.ReaderFrom); ok { + return rt.ReadFrom(r) + } + _, nw, err := readFromWriteTo(r, w) return nw, err } @@ -476,6 +482,12 @@ func (w *Writer) write(p []byte) (int, error) { // // It implements the io.ReaderFrom interface. func (w *Writer) ReadFrom(r io.Reader) (int64, error) { + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := r.(io.WriterTo); ok { + return wt.WriteTo(w) + } + nr, _, err := readFromWriteTo(r, w) return nr, err } diff --git a/blob/blob_reader_test.go b/blob/blob_reader_test.go index 0db2621d8b..6fed605eb3 100644 --- a/blob/blob_reader_test.go +++ b/blob/blob_reader_test.go @@ -15,7 +15,9 @@ package blob_test import ( + "bytes" "context" + "io" "testing" "testing/iotest" @@ -41,12 +43,29 @@ func TestReader(t *testing.T) { bucket.WriteAll(ctx, myKey, data, nil) // Create a blob.Reader. - r, err := bucket.NewReader(ctx, myKey, nil) + r1, err := bucket.NewReader(ctx, myKey, nil) if err != nil { t.Fatal(err) } - defer r.Close() - if err := iotest.TestReader(r, data); err != nil { + r1.Close() + if err := iotest.TestReader(r1, data); err != nil { t.Error(err) } + + // Create another blob.Reader to exercise the ReadFrom code path + r2, err := bucket.NewReader(ctx, myKey, nil) + if err != nil { + t.Fatal(err) + } + defer r2.Close() + + var buffer bytes.Buffer + n, err := io.Copy(&buffer, r2) + if err != nil { + t.Fatal(err) + } else if n != int64(len(data)) { + t.Fatal("wrote fewer bytes than expected") + } else if !bytes.Equal(buffer.Bytes(), data) { + t.Fatal("wrote invalid bytes") + } } diff --git a/blob/drivertest/drivertest.go b/blob/drivertest/drivertest.go index 6c10034587..d774868d69 100644 --- a/blob/drivertest/drivertest.go +++ b/blob/drivertest/drivertest.go @@ -2624,12 +2624,21 @@ func benchmarkRead(b *testing.B, bkt *blob.Bucket) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { + var buffer bytes.Buffer + buffer.Grow(len(content)) + for pb.Next() { - buf, err := bkt.ReadAll(ctx, key) + buffer.Reset() + r, err := bkt.NewReader(ctx, key, nil) if err != nil { b.Error(err) } - if !bytes.Equal(buf, content) { + + if _, err = io.Copy(&buffer, r); err != nil { + b.Error(err) + } + r.Close() + if !bytes.Equal(buffer.Bytes(), content) { b.Error("read didn't match write") } }