Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up uniter downloads #6325

Merged
merged 7 commits into from Sep 27, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
180 changes: 83 additions & 97 deletions downloader/download.go
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/juju/errors"
"github.com/juju/utils"
"gopkg.in/tomb.v1"
)

// Request holds a single download request.
Expand All @@ -27,125 +26,88 @@ type Request struct {
// the download is invalid then the func must return errors.NotValid.
// If no func is provided then no verification happens.
Verify func(*os.File) error

// Abort is a channel that will cancel the download when it is closed.
Abort <-chan struct{}
}

// Status represents the status of a completed download.
type Status struct {
// File holds the downloaded data on success.
File *os.File
// Filename is the name of the file which holds the downloaded
// data on success.
Filename string

// Err describes any error encountered while downloading.
Err error
}

// Download can download a file from the network.
type Download struct {
tomb tomb.Tomb
done chan Status
openBlob func(*url.URL) (io.ReadCloser, error)
}

// StartDownload returns a new Download instance based on the provided
// request. openBlob is used to gain access to the blob, whether through
// an HTTP request or some other means.
// StartDownload starts a new download as specified by `req` using
// `openBlob` to actually pull the remote data.
func StartDownload(req Request, openBlob func(*url.URL) (io.ReadCloser, error)) *Download {
dl := newDownload(openBlob)
go dl.run(req)
return dl
}

func newDownload(openBlob func(*url.URL) (io.ReadCloser, error)) *Download {
if openBlob == nil {
openBlob = NewHTTPBlobOpener(utils.NoVerifySSLHostnames)
}
return &Download{
done: make(chan Status),
dl := &Download{
done: make(chan Status, 1),
openBlob: openBlob,
}
go dl.run(req)
return dl
}

// Stop stops any download that's in progress.
func (dl *Download) Stop() {
dl.tomb.Kill(nil)
dl.tomb.Wait()
// Download can download a file from the network.
type Download struct {
done chan Status
openBlob func(*url.URL) (io.ReadCloser, error)
}

// Done returns a channel that receives a status when the download has
// completed. It is the receiver's responsibility to close and remove
// the received file.
// completed or is aborted. Exactly one Status value will be sent for
// each download once it finishes (successfully or otherwise) or is
// aborted.
//
// It is the receiver's responsibility to handle and remove the
// downloaded file.
func (dl *Download) Done() <-chan Status {
return dl.done
}

// Wait blocks until the download completes or the abort channel receives.
func (dl *Download) Wait(abort <-chan struct{}) (*os.File, error) {
defer dl.Stop()

select {
case <-abort:
logger.Infof("download aborted")
return nil, errors.New("aborted")
case status := <-dl.Done():
if status.Err != nil {
if status.File != nil {
if err := status.File.Close(); err != nil {
logger.Errorf("failed to close file: %v", err)
}
}
return nil, errors.Trace(status.Err)
}
return status.File, nil
}
// Wait blocks until the download finishes (successfully or
// otherwise), or the download is aborted. There will only be a
// filename if err is nil.
func (dl *Download) Wait() (string, error) {
// No select required here because each download will always
// return a value once it completes. Downloads can be aborted via
// the Abort channel provided a creation time.
status := <-dl.Done()
return status.Filename, errors.Trace(status.Err)
}

func (dl *Download) run(req Request) {
defer dl.tomb.Done()

// TODO(dimitern) 2013-10-03 bug #1234715
// Add a testing HTTPS storage to verify the
// disableSSLHostnameVerification behavior here.
file, err := download(req, dl.openBlob)
filename, err := dl.download(req)
if err != nil {
err = errors.Annotatef(err, "cannot download %q", req.URL)
}

if err == nil {
err = errors.Trace(err)
} else {
logger.Infof("download complete (%q)", req.URL)
if req.Verify != nil {
err = verifyDownload(file, req)
}
}

status := Status{
File: file,
Err: err,
}
select {
case dl.done <- status:
// no-op
case <-dl.tomb.Dying():
cleanTempFile(file)
}
}

func verifyDownload(file *os.File, req Request) error {
err := req.Verify(file)
if err != nil {
if errors.IsNotValid(err) {
logger.Errorf("download of %s invalid: %v", req.URL, err)
err = verifyDownload(filename, req)
if err != nil {
os.Remove(filename)
filename = ""
}
return errors.Trace(err)
}
logger.Infof("download verified (%q)", req.URL)

if _, err := file.Seek(0, os.SEEK_SET); err != nil {
logger.Errorf("failed to seek to beginning of file: %v", err)
return errors.Trace(err)
// No select needed here because the channel has a size of 1 and
// will only be written to once.
dl.done <- Status{
Filename: filename,
Err: err,
}
return nil
}

func download(req Request, openBlob func(*url.URL) (io.ReadCloser, error)) (file *os.File, err error) {
func (dl *Download) download(req Request) (filename string, err error) {
logger.Infof("downloading from %s", req.URL)

dir := req.TargetDir
Expand All @@ -154,37 +116,61 @@ func download(req Request, openBlob func(*url.URL) (io.ReadCloser, error)) (file
}
tempFile, err := ioutil.TempFile(dir, "inprogress-")
if err != nil {
return nil, errors.Trace(err)
return "", errors.Trace(err)
}
defer func() {
tempFile.Close()
if err != nil {
cleanTempFile(tempFile)
os.Remove(tempFile.Name())
}
}()

reader, err := openBlob(req.URL)
blobReader, err := dl.openBlob(req.URL)
if err != nil {
return nil, errors.Trace(err)
return "", errors.Trace(err)
}
defer reader.Close()
defer blobReader.Close()

reader := &abortableReader{blobReader, req.Abort}
_, err = io.Copy(tempFile, reader)
if err != nil {
return nil, errors.Trace(err)
return "", errors.Trace(err)
}
if _, err := tempFile.Seek(0, 0); err != nil {
return nil, errors.Trace(err)

return tempFile.Name(), nil
}

// abortableReader wraps a Reader, returning an error from Read calls
// if the abort channel provided is closed.
type abortableReader struct {
r io.Reader
abort <-chan struct{}
}

// Read implements io.Reader.
func (ar *abortableReader) Read(p []byte) (int, error) {
select {
case <-ar.abort:
return 0, errors.New("download aborted")
default:
}
return tempFile, nil
return ar.r.Read(p)
}

func cleanTempFile(f *os.File) {
if f == nil {
return
func verifyDownload(filename string, req Request) error {
if req.Verify == nil {
return nil
}

f.Close()
if err := os.Remove(f.Name()); err != nil {
logger.Errorf("cannot remove temp file %q: %v", f.Name(), err)
file, err := os.Open(filename)
if err != nil {
return errors.Annotate(err, "opening for verify")
}
defer file.Close()

if err := req.Verify(file); err != nil {
return errors.Trace(err)
}
logger.Infof("download verified (%q)", req.URL)
return nil
}
85 changes: 42 additions & 43 deletions downloader/download_test.go
Expand Up @@ -8,7 +8,6 @@ import (
"net/url"
"os"
"path/filepath"
"time"

"github.com/juju/errors"
gitjujutesting "github.com/juju/testing"
Expand Down Expand Up @@ -65,13 +64,11 @@ func (s *DownloadSuite) testDownload(c *gc.C, hostnameVerification utils.SSLHost
downloader.NewHTTPBlobOpener(hostnameVerification),
)
status := <-d.Done()
defer status.File.Close()
c.Assert(status.Err, gc.IsNil)
c.Assert(status.File, gc.NotNil)

dir, _ := filepath.Split(status.File.Name())
dir, _ := filepath.Split(status.Filename)
c.Assert(filepath.Clean(dir), gc.Equals, tmp)
assertFileContents(c, status.File, "archive")
assertFileContents(c, status.Filename, "archive")
}

func (s *DownloadSuite) TestDownloadWithoutDisablingSSLHostnameVerification(c *gc.C) {
Expand All @@ -84,36 +81,18 @@ func (s *DownloadSuite) TestDownloadWithDisablingSSLHostnameVerification(c *gc.C

func (s *DownloadSuite) TestDownloadError(c *gc.C) {
gitjujutesting.Server.Response(404, nil, nil)
d := downloader.StartDownload(
downloader.Request{
URL: s.URL(c, "/archive.tgz"),
TargetDir: c.MkDir(),
},
downloader.NewHTTPBlobOpener(utils.VerifySSLHostnames),
)
status := <-d.Done()
c.Assert(status.File, gc.IsNil)
c.Assert(status.Err, gc.ErrorMatches, `cannot download ".*": bad http response: 404 Not Found`)
}

func (s *DownloadSuite) TestStop(c *gc.C) {
tmp := c.MkDir()
d := downloader.StartDownload(
downloader.Request{
URL: s.URL(c, "/x.tgz"),
URL: s.URL(c, "/archive.tgz"),
TargetDir: tmp,
},
downloader.NewHTTPBlobOpener(utils.VerifySSLHostnames),
)
d.Stop()
select {
case status := <-d.Done():
c.Fatalf("received status %#v after stop", status)
case <-time.After(testing.ShortWait):
}
infos, err := ioutil.ReadDir(tmp)
c.Assert(err, jc.ErrorIsNil)
c.Assert(infos, gc.HasLen, 0)
filename, err := d.Wait()
c.Assert(filename, gc.Equals, "")
c.Assert(err, gc.ErrorMatches, `bad http response: 404 Not Found`)
checkDirEmpty(c, tmp)
}

func (s *DownloadSuite) TestVerifyValid(c *gc.C) {
Expand All @@ -131,11 +110,10 @@ func (s *DownloadSuite) TestVerifyValid(c *gc.C) {
},
downloader.NewHTTPBlobOpener(utils.VerifySSLHostnames),
)
status := <-dl.Done()
c.Assert(status.Err, jc.ErrorIsNil)

filename, err := dl.Wait()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks much nicer.

c.Assert(err, jc.ErrorIsNil)
c.Check(filename, gc.Not(gc.Equals), "")
stub.CheckCallNames(c, "Verify")
stub.CheckCall(c, 0, "Verify", status.File)
}

func (s *DownloadSuite) TestVerifyInvalid(c *gc.C) {
Expand All @@ -154,19 +132,40 @@ func (s *DownloadSuite) TestVerifyInvalid(c *gc.C) {
},
downloader.NewHTTPBlobOpener(utils.VerifySSLHostnames),
)
status := <-dl.Done()

c.Check(errors.Cause(status.Err), gc.Equals, invalid)
filename, err := dl.Wait()
c.Check(filename, gc.Equals, "")
c.Check(errors.Cause(err), gc.Equals, invalid)
stub.CheckCallNames(c, "Verify")
stub.CheckCall(c, 0, "Verify", status.File)
checkDirEmpty(c, tmp)
}

func (s *DownloadSuite) TestAbort(c *gc.C) {
tmp := c.MkDir()
gitjujutesting.Server.Response(200, nil, []byte("archive"))
abort := make(chan struct{})
close(abort)
dl := downloader.StartDownload(
downloader.Request{
URL: s.URL(c, "/archive.tgz"),
TargetDir: tmp,
Abort: abort,
},
downloader.NewHTTPBlobOpener(utils.VerifySSLHostnames),
)
filename, err := dl.Wait()
c.Check(filename, gc.Equals, "")
c.Check(err, gc.ErrorMatches, "download aborted")
checkDirEmpty(c, tmp)
}

func assertFileContents(c *gc.C, filename, expect string) {
got, err := ioutil.ReadFile(filename)
c.Assert(err, jc.ErrorIsNil)
c.Check(string(got), gc.Equals, expect)
}

func assertFileContents(c *gc.C, f *os.File, expect string) {
got, err := ioutil.ReadAll(f)
func checkDirEmpty(c *gc.C, dir string) {
files, err := ioutil.ReadDir(dir)
c.Assert(err, jc.ErrorIsNil)
if !c.Check(string(got), gc.Equals, expect) {
info, err := f.Stat()
c.Assert(err, jc.ErrorIsNil)
c.Logf("info %#v", info)
}
c.Check(files, gc.HasLen, 0)
}