diff --git a/cmd/copy.go b/cmd/copy.go index 8a6cce1..203646b 100644 --- a/cmd/copy.go +++ b/cmd/copy.go @@ -135,7 +135,7 @@ func setOptions(c *cli.Context, supervisor *copier.Supervisor) error { supervisor.ExecTimeout = timeout } - if r := c.Int("retries"); r > 0 { + if r := c.Int("retries"); r >= 0 { supervisor.Retries = r } diff --git a/exec.go b/exec.go index e07edf6..ac085f9 100644 --- a/exec.go +++ b/exec.go @@ -2,6 +2,7 @@ package copier import ( "context" + "fmt" "io" "os" @@ -11,9 +12,12 @@ import ( // An Exec copies files at a given ratelimit. type Exec struct { + opened bool size int64 ctx context.Context - reader *ProxyReader + r io.ReadCloser // input file + w io.WriteCloser // output file + pr *ProxyReader From string To string Speed float64 @@ -62,26 +66,29 @@ func (e *Exec) Execute() error { e.status = StatusCopied } - r, err := os.Open(e.From) + e.opened = true + + e.r, err = os.Open(e.From) if err != nil { return errors.Annotate(err, "source") } - defer r.Close() + defer e.r.Close() w, err := os.Create(e.To) if err != nil { return errors.Annotate(err, "destination") } defer w.Close() + e.w = w - rr := shapeio.NewReaderWithContext(r, e.ctx) + rr := shapeio.NewReaderWithContext(e.r, e.ctx) rr.SetRateLimit(e.Speed) - e.reader = NewProxyReader(rr) // Used for progressbar - defer e.reader.Close() + e.pr = NewProxyReader(rr) // Used for progressbar + defer e.pr.Close() e.Ready <- struct{}{} - if _, err = io.Copy(w, e.reader); err != nil { + if _, err = io.Copy(e.w, e.pr); err != nil { return errors.Annotate(err, "copy") } @@ -105,5 +112,23 @@ func (e *Exec) Size() int64 { // Reader returns the file reader. func (e *Exec) Reader() *ProxyReader { - return e.reader + return e.pr +} + +// ForceClose force closes all its IO objects. +func (e *Exec) ForceClose() { + if !e.opened { + return + } + + e.asyncClose(e.pr) + e.asyncClose(e.r) + e.asyncClose(e.w) + fmt.Println("Forced close.") +} + +func (e *Exec) asyncClose(c io.Closer) { + if c != nil { + go c.Close() + } } diff --git a/supervisor.go b/supervisor.go index 7f214c5..4c57834 100644 --- a/supervisor.go +++ b/supervisor.go @@ -92,18 +92,17 @@ func (s *Supervisor) Execute() error { return nil } -func (s *Supervisor) execute(from, to string, retries int, err2 error) error { - if retries == 0 { +func (s *Supervisor) execute(from, to string, retries int, err2 error) (err error) { + if retries < 0 { return err2 // keep last error } errc := make(chan error) - var err error ctx, cancel := context.WithTimeout(s.Context, s.ExecTimeout) defer cancel() - cp := NewExecWithContext(ctx, from, to, s.Speed) + cp := NewExec(from, to, s.Speed) cp.Speed = s.Speed go func() { if err = cp.Execute(); err != nil { @@ -126,6 +125,7 @@ func (s *Supervisor) execute(from, to string, retries int, err2 error) error { ready = true case <-ctx.Done(): if ctx.Err() != nil { + cp.ForceClose() time.Sleep(s.RetryInterval) err = s.execute(from, to, retries-1, ctx.Err()) } @@ -133,6 +133,7 @@ func (s *Supervisor) execute(from, to string, retries int, err2 error) error { terminated = true case err = <-errc: if err != nil { + cp.ForceClose() time.Sleep(s.RetryInterval) err = s.execute(from, to, retries-1, err) } diff --git a/utils.go b/utils.go index 177aa13..5582cbe 100644 --- a/utils.go +++ b/utils.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "strings" + "sync" ) const ( @@ -135,8 +136,10 @@ func MoreRecent(dst, src string) (bool, error) { // A ProxyReader is proxified reader. type ProxyReader struct { + m sync.Mutex reader io.Reader writer io.Writer + closed bool c chan int } @@ -151,19 +154,39 @@ func NewProxyReader(r io.Reader) *ProxyReader { // Read implements io.Reader func (r *ProxyReader) Read(b []byte) (n int, err error) { n, err = r.reader.Read(b) - r.c <- n + + if !r.Closed() { + r.c <- n + } return } // Close the reader when it implements io.Closer func (r *ProxyReader) Close() (err error) { - if closer, ok := r.reader.(io.Closer); ok { - return closer.Close() + if r.Closed() { + return nil + } + + if r.reader != nil { + if closer, ok := r.reader.(io.Closer); ok { + return closer.Close() + } } close(r.c) + + r.m.Lock() + defer r.m.Unlock() + r.closed = true return } +// Closed returns the reader status. +func (r *ProxyReader) Closed() bool { + r.m.Lock() + defer r.m.Unlock() + return r.closed +} + // ReadChunk returns the number of read bytes for a chunk. func (r *ProxyReader) ReadChunk() <-chan int { return r.c