Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
Already on GitHub? Sign in to your account
add MultiReaderSeeker #193
Merged
Commits
Jump to file or symbol
Failed to load files and symbols.
| @@ -0,0 +1,118 @@ | ||
| +package utils | ||
| + | ||
| +import ( | ||
| + "io" | ||
| + "os" | ||
| + | ||
| + "github.com/juju/errors" | ||
| +) | ||
| + | ||
| +type multiReaderSeeker struct { | ||
| + readers []io.ReadSeeker | ||
| + sizes []int64 | ||
| + index int | ||
| + offset int64 // offset in current file. | ||
| +} | ||
| + | ||
| +// NewMultiReaderSeeker returns an io.ReadSeeker that combines | ||
| +// all the given readers into a single one. It assumes that | ||
| +// all the seekers are initially positioned at the start. | ||
| +func NewMultiReaderSeeker(readers ...io.ReadSeeker) io.ReadSeeker { | ||
| + r := &multiReaderSeeker{ | ||
| + readers: readers, | ||
| + sizes: make([]int64, len(readers)), | ||
| + } | ||
| + for i := range r.sizes { | ||
| + r.sizes[i] = -1 | ||
| + } | ||
| + return r | ||
| +} | ||
| + | ||
| +// Read implements io.Reader.Read. | ||
| +func (r *multiReaderSeeker) Read(buf []byte) (int, error) { | ||
| + if r.index >= len(r.readers) { | ||
| + return 0, io.EOF | ||
| + } | ||
| + n, err := r.readers[r.index].Read(buf) | ||
| + r.offset += int64(n) | ||
| + if err == io.EOF { | ||
| + // We've got to the end of a file so we | ||
| + // now know how big it is. | ||
| + r.sizes[r.index] = r.offset | ||
| + r.index++ | ||
| + r.offset = 0 | ||
| + err = nil | ||
| + } | ||
| + return n, err | ||
| +} | ||
| + | ||
| +// Seek implements io.Seeker.Seek. It can only be used to seek to the | ||
| +// start. | ||
| +func (r *multiReaderSeeker) Seek(offset int64, whence int) (int64, error) { | ||
| + if offset == 0 && whence == 0 { | ||
| + // Easy special case: seeking to the very start. | ||
| + for _, reader := range r.readers { | ||
| + if _, err := reader.Seek(0, 0); err != nil { | ||
| + return 0, errors.Trace(err) | ||
| + } | ||
| + } | ||
| + } | ||
| + // Find all the file sizes because we may need them. | ||
| + // Technically we could avoid some seeks here, but | ||
| + // it's probably not worth it. | ||
| + for i, size := range r.sizes { | ||
| + if size != -1 { | ||
| + continue | ||
| + } | ||
| + size, err := r.readers[i].Seek(0, 2) | ||
| + if err != nil { | ||
| + return 0, errors.Annotate(err, "cannot seek to end") | ||
| + } | ||
| + r.sizes[i] = size | ||
| + } | ||
| + switch whence { | ||
| + case os.SEEK_SET: | ||
| + // Nothing to do. | ||
| + case os.SEEK_END: | ||
| + totalSize := int64(0) | ||
| + for _, size := range r.sizes { | ||
| + totalSize += size | ||
| + } | ||
| + offset = totalSize + offset | ||
| + case os.SEEK_CUR: | ||
| + size := int64(0) | ||
| + for i := 0; i < r.index; i++ { | ||
| + size += r.sizes[i] | ||
| + } | ||
| + offset = size + r.offset + offset | ||
| + default: | ||
| + return 0, errors.New("unknown whence value in seek") | ||
| + } | ||
| + if offset < 0 { | ||
| + return 0, errors.New("negative position") | ||
| + } | ||
| + start := int64(0) | ||
| + for i, size := range r.sizes { | ||
| + if offset < start+size { | ||
| + var err error | ||
| + _, err = r.readers[i].Seek(offset-start, 0) | ||
| + if err != nil { | ||
| + return 0, errors.Annotate(err, "cannot seek into file") | ||
| + } | ||
| + // Make sure that all the subsequent readers are | ||
| + // positioned at the start. | ||
| + for _, rr := range r.readers[i+1:] { | ||
| + if _, err := rr.Seek(0, os.SEEK_SET); err != nil { | ||
| + return 0, errors.Annotate(err, "cannot seek to start of file") | ||
| + } | ||
| + } | ||
| + r.index = i | ||
| + r.offset = offset - start | ||
| + return offset, nil | ||
| + } | ||
| + start += size | ||
| + } | ||
| + r.index = len(r.readers) | ||
| + r.offset = offset - start | ||
| + return offset, nil | ||
| +} |
| @@ -0,0 +1,184 @@ | ||
| +package utils_test | ||
| + | ||
| +import ( | ||
| + "io" | ||
| + "io/ioutil" | ||
| + "strings" | ||
| + "testing/iotest" | ||
| + | ||
| + jc "github.com/juju/testing/checkers" | ||
| + "github.com/juju/utils" | ||
| + gc "gopkg.in/check.v1" | ||
| +) | ||
| + | ||
| +type multiReaderSeekerSuite struct{} | ||
| + | ||
| +var _ = gc.Suite(&multiReaderSeekerSuite{}) | ||
| + | ||
| +func (*multiReaderSeekerSuite) TestSequentialRead(c *gc.C) { | ||
| + parts := []string{ | ||
| + "one", | ||
| + "two", | ||
| + "three", | ||
| + "four", | ||
| + } | ||
| + r := newMultiStringReader(parts) | ||
| + data, err := ioutil.ReadAll(r) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(string(data), gc.Equals, strings.Join(parts, "")) | ||
| +} | ||
| + | ||
| +func (*multiReaderSeekerSuite) TestSeekStart(c *gc.C) { | ||
| + parts := []string{ | ||
| + "one", | ||
| + "two", | ||
| + "three", | ||
| + "four", | ||
| + } | ||
| + all := strings.Join(parts, "") | ||
| + for off := int64(0); off <= int64(len(all)); off++ { | ||
| + r := newMultiStringReader(parts) | ||
| + gotOff, err := r.Seek(off, 0) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(gotOff, gc.Equals, off) | ||
| + | ||
| + data, err := ioutil.ReadAll(r) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(string(data), gc.Equals, all[off:]) | ||
| + } | ||
| +} | ||
| + | ||
| +func (*multiReaderSeekerSuite) TestSeekEnd(c *gc.C) { | ||
| + parts := []string{ | ||
| + "one", | ||
| + "two", | ||
| + "three", | ||
| + "four", | ||
| + } | ||
| + all := strings.Join(parts, "") | ||
| + for off := int64(0); off <= int64(len(all)); off++ { | ||
| + r := newMultiStringReader(parts) | ||
| + expectOff := int64(len(all)) - off | ||
| + gotOff, err := r.Seek(-off, 2) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(gotOff, gc.Equals, expectOff) | ||
| + | ||
| + data, err := ioutil.ReadAll(r) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(string(data), gc.Equals, all[expectOff:]) | ||
| + } | ||
| +} | ||
| + | ||
| +func (*multiReaderSeekerSuite) TestSeekCur(c *gc.C) { | ||
| + parts := []string{ | ||
| + "one", | ||
| + "two", | ||
| + "three", | ||
| + "four", | ||
| + } | ||
| + all := strings.Join(parts, "") | ||
| + for off := int64(0); off <= int64(len(all)); off++ { | ||
| + for newOff := int64(0); newOff <= int64(len(all)); newOff++ { | ||
| + readers := make([]io.ReadSeeker, len(parts)) | ||
| + for i, part := range parts { | ||
| + readers[i] = strings.NewReader(part) | ||
| + } | ||
| + r := utils.NewMultiReaderSeeker(readers...) | ||
| + gotOff, err := r.Seek(off, 0) | ||
| + c.Assert(gotOff, gc.Equals, off) | ||
| + c.Assert(err, jc.ErrorIsNil) | ||
| + | ||
| + diff := newOff - off | ||
| + gotNewOff, err := r.Seek(diff, 1) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(gotNewOff, gc.Equals, newOff) | ||
| + | ||
| + data, err := ioutil.ReadAll(r) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(string(data), gc.Equals, all[newOff:]) | ||
| + } | ||
| + } | ||
| +} | ||
| + | ||
| +func (*multiReaderSeekerSuite) TestSeekAfterRead(c *gc.C) { | ||
| + parts := []string{ | ||
| + "one", | ||
| + "two", | ||
| + "three", | ||
| + "four", | ||
| + } | ||
| + all := strings.Join(parts, "") | ||
| + r := newMultiStringReader(parts) | ||
| + data, err := ioutil.ReadAll(iotest.OneByteReader(r)) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(string(data), gc.Equals, all) | ||
| + | ||
| + off, err := r.Seek(-8, 2) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(off, gc.Equals, int64(len(all)-8)) | ||
| + | ||
| + data, err = ioutil.ReadAll(r) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(string(data), gc.Equals, "hreefour") | ||
| +} | ||
| + | ||
| +func (*multiReaderSeekerSuite) TestSeekNegative(c *gc.C) { | ||
| + r := newMultiStringReader([]string{"one", "two"}) | ||
| + | ||
| + _, err := r.Seek(-1, 0) | ||
| + c.Assert(err, gc.ErrorMatches, "negative position") | ||
| + | ||
| + n, err := r.Seek(0, 0) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(n, gc.Equals, int64(0)) | ||
| + | ||
| + _, err = r.Seek(-7, 2) | ||
| + c.Assert(err, gc.ErrorMatches, "negative position") | ||
| + | ||
| + n, err = r.Seek(0, 0) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(n, gc.Equals, int64(0)) | ||
| + | ||
| + _, err = r.Seek(-1, 1) | ||
| + c.Assert(err, gc.ErrorMatches, "negative position") | ||
| + | ||
| + n, err = r.Seek(0, 0) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(n, gc.Equals, int64(0)) | ||
| +} | ||
| + | ||
| +func (*multiReaderSeekerSuite) TestSeekPastEnd(c *gc.C) { | ||
| + r := newMultiStringReader([]string{"one", "two"}) | ||
| + | ||
| + n, err := r.Seek(100, 0) | ||
| + c.Assert(err, jc.ErrorIsNil) | ||
| + c.Assert(n, gc.Equals, int64(100)) | ||
| + | ||
| + nr, err := r.Read(make([]byte, 1)) | ||
| + c.Assert(nr, gc.Equals, 0) | ||
| + c.Assert(err, gc.Equals, io.EOF) | ||
| + | ||
| + n, err = r.Seek(-5, 1) | ||
| + c.Assert(err, jc.ErrorIsNil) | ||
| + c.Assert(n, gc.Equals, int64(95)) | ||
| + | ||
| + nr, err = r.Read(make([]byte, 1)) | ||
| + c.Assert(nr, gc.Equals, 0) | ||
| + c.Assert(err, gc.Equals, io.EOF) | ||
| + | ||
| + n, err = r.Seek(-94, 1) | ||
| + c.Assert(err, jc.ErrorIsNil) | ||
| + c.Assert(n, gc.Equals, int64(1)) | ||
| + | ||
| + data, err := ioutil.ReadAll(r) | ||
| + c.Assert(err, gc.IsNil) | ||
| + c.Assert(string(data), gc.Equals, "netwo") | ||
| +} | ||
| + | ||
| +func newMultiStringReader(parts []string) io.ReadSeeker { | ||
| + readers := make([]io.ReadSeeker, len(parts)) | ||
| + for i, part := range parts { | ||
| + readers[i] = strings.NewReader(part) | ||
| + } | ||
| + return utils.NewMultiReaderSeeker(readers...) | ||
| +} |