diff --git a/blob/blob.go b/blob/blob.go index f530e48443..729c34fe4f 100644 --- a/blob/blob.go +++ b/blob/blob.go @@ -233,6 +233,12 @@ func (r *Reader) As(i interface{}) bool { // // It implements the io.WriterTo interface. func (r *Reader) WriteTo(w io.Writer) (int64, error) { + // If the writer has a ReaderFrom method, use it to do the copy. + // Avoids an allocation and a copy. + if rt, ok := w.(io.ReaderFrom); ok { + return rt.ReadFrom(r) + } + _, nw, err := readFromWriteTo(r, w) return nw, err } @@ -476,6 +482,12 @@ func (w *Writer) write(p []byte) (int, error) { // // It implements the io.ReaderFrom interface. func (w *Writer) ReadFrom(r io.Reader) (int64, error) { + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := r.(io.WriterTo); ok { + return wt.WriteTo(w) + } + nr, _, err := readFromWriteTo(r, w) return nr, err } diff --git a/blob/blob_reader_test.go b/blob/blob_reader_test.go index 0db2621d8b..6fed605eb3 100644 --- a/blob/blob_reader_test.go +++ b/blob/blob_reader_test.go @@ -15,7 +15,9 @@ package blob_test import ( + "bytes" "context" + "io" "testing" "testing/iotest" @@ -41,12 +43,29 @@ func TestReader(t *testing.T) { bucket.WriteAll(ctx, myKey, data, nil) // Create a blob.Reader. - r, err := bucket.NewReader(ctx, myKey, nil) + r1, err := bucket.NewReader(ctx, myKey, nil) if err != nil { t.Fatal(err) } - defer r.Close() - if err := iotest.TestReader(r, data); err != nil { + r1.Close() + if err := iotest.TestReader(r1, data); err != nil { t.Error(err) } + + // Create another blob.Reader to exercise the ReadFrom code path + r2, err := bucket.NewReader(ctx, myKey, nil) + if err != nil { + t.Fatal(err) + } + defer r2.Close() + + var buffer bytes.Buffer + n, err := io.Copy(&buffer, r2) + if err != nil { + t.Fatal(err) + } else if n != int64(len(data)) { + t.Fatal("wrote fewer bytes than expected") + } else if !bytes.Equal(buffer.Bytes(), data) { + t.Fatal("wrote invalid bytes") + } } diff --git a/blob/drivertest/drivertest.go b/blob/drivertest/drivertest.go index 6c10034587..d774868d69 100644 --- a/blob/drivertest/drivertest.go +++ b/blob/drivertest/drivertest.go @@ -2624,12 +2624,21 @@ func benchmarkRead(b *testing.B, bkt *blob.Bucket) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { + var buffer bytes.Buffer + buffer.Grow(len(content)) + for pb.Next() { - buf, err := bkt.ReadAll(ctx, key) + buffer.Reset() + r, err := bkt.NewReader(ctx, key, nil) if err != nil { b.Error(err) } - if !bytes.Equal(buf, content) { + + if _, err = io.Copy(&buffer, r); err != nil { + b.Error(err) + } + r.Close() + if !bytes.Equal(buffer.Bytes(), content) { b.Error("read didn't match write") } }