diff --git a/progressutil/iocopy.go b/progressutil/iocopy.go index d0572de..0ad5cb6 100644 --- a/progressutil/iocopy.go +++ b/progressutil/iocopy.go @@ -15,32 +15,22 @@ package progressutil import ( + "errors" "fmt" "io" "sync" "time" ) -type copyReader struct { - reader io.Reader - current int64 - total int64 - done bool - doneLock sync.Mutex - pb *ProgressBar -} - -func (cr *copyReader) getDone() bool { - cr.doneLock.Lock() - val := cr.done - cr.doneLock.Unlock() - return val -} +var ( + ErrAlreadyStarted = errors.New("cannot add copies after PrintAndWait has been called") +) -func (cr *copyReader) setDone(val bool) { - cr.doneLock.Lock() - cr.done = val - cr.doneLock.Unlock() +type copyReader struct { + reader io.Reader + current int64 + total int64 + pb *ProgressBar } func (cr *copyReader) Read(p []byte) (int, error) { @@ -50,9 +40,6 @@ func (cr *copyReader) Read(p []byte) (int, error) { if err == nil { err = err1 } - if err != nil { - cr.setDone(true) - } return n, err } @@ -66,12 +53,19 @@ func (cr *copyReader) updateProgressBar() error { return cr.pb.SetCurrentProgress(progress) } +// NewCopyProgressPrinter returns a new CopyProgressPrinter +func NewCopyProgressPrinter() *CopyProgressPrinter { + return &CopyProgressPrinter{results: make(chan error)} +} + // CopyProgressPrinter will perform an arbitrary number of io.Copy calls, while // continually printing the progress of each copy. type CopyProgressPrinter struct { - readers []*copyReader - errors []error + results chan error + lock sync.Mutex + readers []*copyReader + started bool pbp *ProgressBarPrinter } @@ -79,8 +73,15 @@ type CopyProgressPrinter struct { // will be made to copy bytes from reader to dest, and name and size will be // used to label the progress bar and display how much progress has been made. // If size is 0, the total size of the reader is assumed to be unknown. -func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int64, dest io.Writer) { +// AddCopy can only be called before PrintAndWait; otherwise, ErrAlreadyStarted +// will be returned. +func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int64, dest io.Writer) error { cpp.lock.Lock() + defer cpp.lock.Unlock() + + if cpp.started { + return ErrAlreadyStarted + } if cpp.pbp == nil { cpp.pbp = &ProgressBarPrinter{} cpp.pbp.PadToBeEven = true @@ -96,60 +97,56 @@ func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int6 cr.pb.SetPrintAfter(cr.formattedProgress()) cpp.readers = append(cpp.readers, cr) - cpp.lock.Unlock() go func() { _, err := io.Copy(dest, cr) - if err != nil { - cpp.lock.Lock() - cpp.errors = append(cpp.errors, err) - cpp.lock.Unlock() - } + cpp.results <- err }() + return nil } // PrintAndWait will print the progress for each copy operation added with // AddCopy to printTo every printInterval. This will continue until every added // copy is finished, or until cancel is written to. +// PrintAndWait may only be called once; any subsequent calls will immediately +// return ErrAlreadyStarted. After PrintAndWait has been called, no more +// copies may be added to the CopyProgressPrinter. func (cpp *CopyProgressPrinter) PrintAndWait(printTo io.Writer, printInterval time.Duration, cancel chan struct{}) error { - for { - // If cancel is not nil, see if anything has been written to it. If - // something has, return, otherwise keep drawing. - if cancel != nil { - select { - case <-cancel: - return nil - default: - } - } - - cpp.lock.Lock() - readers := cpp.readers - errors := cpp.errors + cpp.lock.Lock() + if cpp.started { cpp.lock.Unlock() + return ErrAlreadyStarted + } + cpp.started = true + cpp.lock.Unlock() - if len(errors) > 0 { - return errors[0] - } + n := len(cpp.readers) + if n == 0 { + // Nothing to do. + return nil + } - if len(readers) > 0 { + t := time.NewTicker(printInterval) + for i := 0; i < n; { + select { + case <-cancel: + return nil + case <-t.C: _, err := cpp.pbp.Print(printTo) if err != nil { return err } - } else { - } - - allDone := true - for _, r := range readers { - allDone = allDone && r.getDone() - } - if allDone && len(readers) > 0 { - return nil + case err := <-cpp.results: + i++ + if err == nil { + _, err = cpp.pbp.Print(printTo) + } + if err != nil { + return err + } } - - time.Sleep(printInterval) } + return nil } func (cr *copyReader) formattedProgress() string { diff --git a/progressutil/progressbar.go b/progressutil/progressbar.go index 31c6247..224c124 100644 --- a/progressutil/progressbar.go +++ b/progressutil/progressbar.go @@ -191,14 +191,14 @@ func (pbp *ProgressBarPrinter) Print(printTo io.Writer) (bool, error) { } } - allDone := false + allDone := true for _, bar := range bars { if isTerminal(printTo) { bar.printToTerminal(printTo, numColumns, pbp.PadToBeEven, pbp.maxBefore, pbp.maxAfter) } else { bar.printToNonTerminal(printTo) } - allDone = allDone || bar.GetCurrentProgress() == 1 + allDone = allDone && bar.GetCurrentProgress() == 1 } pbp.numLinesInLastPrint = len(bars)