add MultiReaderSeeker #193

Merged
merged 1 commit into from Jan 28, 2016
Jump to file or symbol
Failed to load files and symbols.
+302 −0
Split
View
@@ -0,0 +1,118 @@
+package utils
+
+import (
+ "io"
+ "os"
+
+ "github.com/juju/errors"
+)
+
+type multiReaderSeeker struct {
+ readers []io.ReadSeeker
+ sizes []int64
+ index int
+ offset int64 // offset in current file.
+}
+
+// NewMultiReaderSeeker returns an io.ReadSeeker that combines
+// all the given readers into a single one. It assumes that
+// all the seekers are initially positioned at the start.
+func NewMultiReaderSeeker(readers ...io.ReadSeeker) io.ReadSeeker {
+ r := &multiReaderSeeker{
+ readers: readers,
+ sizes: make([]int64, len(readers)),
+ }
+ for i := range r.sizes {
+ r.sizes[i] = -1
+ }
+ return r
+}
+
+// Read implements io.Reader.Read.
+func (r *multiReaderSeeker) Read(buf []byte) (int, error) {
+ if r.index >= len(r.readers) {
+ return 0, io.EOF
+ }
+ n, err := r.readers[r.index].Read(buf)
+ r.offset += int64(n)
+ if err == io.EOF {
+ // We've got to the end of a file so we
+ // now know how big it is.
+ r.sizes[r.index] = r.offset
+ r.index++
+ r.offset = 0
+ err = nil
+ }
+ return n, err
+}
+
+// Seek implements io.Seeker.Seek. It can only be used to seek to the
+// start.
+func (r *multiReaderSeeker) Seek(offset int64, whence int) (int64, error) {
+ if offset == 0 && whence == 0 {
+ // Easy special case: seeking to the very start.
+ for _, reader := range r.readers {
+ if _, err := reader.Seek(0, 0); err != nil {
+ return 0, errors.Trace(err)
+ }
+ }
+ }
+ // Find all the file sizes because we may need them.
+ // Technically we could avoid some seeks here, but
+ // it's probably not worth it.
+ for i, size := range r.sizes {
+ if size != -1 {
+ continue
+ }
+ size, err := r.readers[i].Seek(0, 2)
+ if err != nil {
+ return 0, errors.Annotate(err, "cannot seek to end")
+ }
+ r.sizes[i] = size
+ }
+ switch whence {
+ case os.SEEK_SET:
+ // Nothing to do.
+ case os.SEEK_END:
+ totalSize := int64(0)
+ for _, size := range r.sizes {
+ totalSize += size
+ }
+ offset = totalSize + offset
+ case os.SEEK_CUR:
+ size := int64(0)
+ for i := 0; i < r.index; i++ {
+ size += r.sizes[i]
+ }
+ offset = size + r.offset + offset
+ default:
+ return 0, errors.New("unknown whence value in seek")
+ }
+ if offset < 0 {
+ return 0, errors.New("negative position")
+ }
+ start := int64(0)
+ for i, size := range r.sizes {
+ if offset < start+size {
+ var err error
+ _, err = r.readers[i].Seek(offset-start, 0)
+ if err != nil {
+ return 0, errors.Annotate(err, "cannot seek into file")
+ }
+ // Make sure that all the subsequent readers are
+ // positioned at the start.
+ for _, rr := range r.readers[i+1:] {
+ if _, err := rr.Seek(0, os.SEEK_SET); err != nil {
+ return 0, errors.Annotate(err, "cannot seek to start of file")
+ }
+ }
+ r.index = i
+ r.offset = offset - start
+ return offset, nil
+ }
+ start += size
+ }
+ r.index = len(r.readers)
+ r.offset = offset - start
+ return offset, nil
+}
View
@@ -0,0 +1,184 @@
+package utils_test
+
+import (
+ "io"
+ "io/ioutil"
+ "strings"
+ "testing/iotest"
+
+ jc "github.com/juju/testing/checkers"
+ "github.com/juju/utils"
+ gc "gopkg.in/check.v1"
+)
+
+type multiReaderSeekerSuite struct{}
+
+var _ = gc.Suite(&multiReaderSeekerSuite{})
+
+func (*multiReaderSeekerSuite) TestSequentialRead(c *gc.C) {
+ parts := []string{
+ "one",
+ "two",
+ "three",
+ "four",
+ }
+ r := newMultiStringReader(parts)
+ data, err := ioutil.ReadAll(r)
+ c.Assert(err, gc.IsNil)
+ c.Assert(string(data), gc.Equals, strings.Join(parts, ""))
+}
+
+func (*multiReaderSeekerSuite) TestSeekStart(c *gc.C) {
+ parts := []string{
+ "one",
+ "two",
+ "three",
+ "four",
+ }
+ all := strings.Join(parts, "")
+ for off := int64(0); off <= int64(len(all)); off++ {
+ r := newMultiStringReader(parts)
+ gotOff, err := r.Seek(off, 0)
+ c.Assert(err, gc.IsNil)
+ c.Assert(gotOff, gc.Equals, off)
+
+ data, err := ioutil.ReadAll(r)
+ c.Assert(err, gc.IsNil)
+ c.Assert(string(data), gc.Equals, all[off:])
+ }
+}
+
+func (*multiReaderSeekerSuite) TestSeekEnd(c *gc.C) {
+ parts := []string{
+ "one",
+ "two",
+ "three",
+ "four",
+ }
+ all := strings.Join(parts, "")
+ for off := int64(0); off <= int64(len(all)); off++ {
+ r := newMultiStringReader(parts)
+ expectOff := int64(len(all)) - off
+ gotOff, err := r.Seek(-off, 2)
+ c.Assert(err, gc.IsNil)
+ c.Assert(gotOff, gc.Equals, expectOff)
+
+ data, err := ioutil.ReadAll(r)
+ c.Assert(err, gc.IsNil)
+ c.Assert(string(data), gc.Equals, all[expectOff:])
+ }
+}
+
+func (*multiReaderSeekerSuite) TestSeekCur(c *gc.C) {
+ parts := []string{
+ "one",
+ "two",
+ "three",
+ "four",
+ }
+ all := strings.Join(parts, "")
+ for off := int64(0); off <= int64(len(all)); off++ {
+ for newOff := int64(0); newOff <= int64(len(all)); newOff++ {
+ readers := make([]io.ReadSeeker, len(parts))
+ for i, part := range parts {
+ readers[i] = strings.NewReader(part)
+ }
+ r := utils.NewMultiReaderSeeker(readers...)
+ gotOff, err := r.Seek(off, 0)
+ c.Assert(gotOff, gc.Equals, off)
+ c.Assert(err, jc.ErrorIsNil)
+
+ diff := newOff - off
+ gotNewOff, err := r.Seek(diff, 1)
+ c.Assert(err, gc.IsNil)
+ c.Assert(gotNewOff, gc.Equals, newOff)
+
+ data, err := ioutil.ReadAll(r)
+ c.Assert(err, gc.IsNil)
+ c.Assert(string(data), gc.Equals, all[newOff:])
+ }
+ }
+}
+
+func (*multiReaderSeekerSuite) TestSeekAfterRead(c *gc.C) {
+ parts := []string{
+ "one",
+ "two",
+ "three",
+ "four",
+ }
+ all := strings.Join(parts, "")
+ r := newMultiStringReader(parts)
+ data, err := ioutil.ReadAll(iotest.OneByteReader(r))
+ c.Assert(err, gc.IsNil)
+ c.Assert(string(data), gc.Equals, all)
+
+ off, err := r.Seek(-8, 2)
+ c.Assert(err, gc.IsNil)
+ c.Assert(off, gc.Equals, int64(len(all)-8))
+
+ data, err = ioutil.ReadAll(r)
+ c.Assert(err, gc.IsNil)
+ c.Assert(string(data), gc.Equals, "hreefour")
+}
+
+func (*multiReaderSeekerSuite) TestSeekNegative(c *gc.C) {
+ r := newMultiStringReader([]string{"one", "two"})
+
+ _, err := r.Seek(-1, 0)
+ c.Assert(err, gc.ErrorMatches, "negative position")
+
+ n, err := r.Seek(0, 0)
+ c.Assert(err, gc.IsNil)
+ c.Assert(n, gc.Equals, int64(0))
+
+ _, err = r.Seek(-7, 2)
+ c.Assert(err, gc.ErrorMatches, "negative position")
+
+ n, err = r.Seek(0, 0)
+ c.Assert(err, gc.IsNil)
+ c.Assert(n, gc.Equals, int64(0))
+
+ _, err = r.Seek(-1, 1)
+ c.Assert(err, gc.ErrorMatches, "negative position")
+
+ n, err = r.Seek(0, 0)
+ c.Assert(err, gc.IsNil)
+ c.Assert(n, gc.Equals, int64(0))
+}
+
+func (*multiReaderSeekerSuite) TestSeekPastEnd(c *gc.C) {
+ r := newMultiStringReader([]string{"one", "two"})
+
+ n, err := r.Seek(100, 0)
+ c.Assert(err, jc.ErrorIsNil)
+ c.Assert(n, gc.Equals, int64(100))
+
+ nr, err := r.Read(make([]byte, 1))
+ c.Assert(nr, gc.Equals, 0)
+ c.Assert(err, gc.Equals, io.EOF)
+
+ n, err = r.Seek(-5, 1)
+ c.Assert(err, jc.ErrorIsNil)
+ c.Assert(n, gc.Equals, int64(95))
+
+ nr, err = r.Read(make([]byte, 1))
+ c.Assert(nr, gc.Equals, 0)
+ c.Assert(err, gc.Equals, io.EOF)
+
+ n, err = r.Seek(-94, 1)
+ c.Assert(err, jc.ErrorIsNil)
+ c.Assert(n, gc.Equals, int64(1))
+
+ data, err := ioutil.ReadAll(r)
+ c.Assert(err, gc.IsNil)
+ c.Assert(string(data), gc.Equals, "netwo")
+}
+
+func newMultiStringReader(parts []string) io.ReadSeeker {
+ readers := make([]io.ReadSeeker, len(parts))
+ for i, part := range parts {
+ readers[i] = strings.NewReader(part)
+ }
+ return utils.NewMultiReaderSeeker(readers...)
+}