Skip to content

Commit

Permalink
Wrap progress updates in a mutex (#1402)
Browse files Browse the repository at this point in the history
Atomically incrementing the number of bytes written isn't sufficient if
we're sending the updates out of order.

I ran the progress tests 50 times after this change and they passed.
  • Loading branch information
jonjohnsonjr committed Jun 29, 2022
1 parent 59b5c06 commit 86f0c4a
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 78 deletions.
20 changes: 10 additions & 10 deletions pkg/v1/remote/multi_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,32 +87,32 @@ func MultiWrite(m map[name.Reference]Taggable, options ...Option) (rerr error) {
return err
}
w := writer{
repo: repo,
client: &http.Client{Transport: tr},
context: o.context,
updates: o.updates,
lastUpdate: &v1.Update{},
backoff: o.retryBackoff,
predicate: o.retryPredicate,
repo: repo,
client: &http.Client{Transport: tr},
context: o.context,
backoff: o.retryBackoff,
predicate: o.retryPredicate,
}

// Collect the total size of blobs and manifests we're about to write.
if o.updates != nil {
w.progress = &progress{updates: o.updates}
w.progress.lastUpdate = &v1.Update{}
defer close(o.updates)
defer func() { _ = sendError(o.updates, rerr) }()
defer func() { _ = w.progress.err(rerr) }()
for _, b := range blobs {
size, err := b.Size()
if err != nil {
return err
}
w.lastUpdate.Total += size
w.progress.total(size)
}
countManifest := func(t Taggable) error {
b, err := t.RawManifest()
if err != nil {
return err
}
w.lastUpdate.Total += int64(len(b))
w.progress.total(int64(len(b)))
return nil
}
for _, i := range images {
Expand Down
69 changes: 69 additions & 0 deletions pkg/v1/remote/progress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2022 Google LLC All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package remote

import (
"io"
"sync"
"sync/atomic"

v1 "github.com/google/go-containerregistry/pkg/v1"
)

type progress struct {
sync.Mutex
updates chan<- v1.Update
lastUpdate *v1.Update
}

func (p *progress) total(delta int64) {
atomic.AddInt64(&p.lastUpdate.Total, delta)
}

func (p *progress) complete(delta int64) {
p.Lock()
defer p.Unlock()
p.updates <- v1.Update{
Total: p.lastUpdate.Total,
Complete: atomic.AddInt64(&p.lastUpdate.Complete, delta),
}
}

func (p *progress) err(err error) error {
if err != nil && p.updates != nil {
p.updates <- v1.Update{Error: err}
}
return err
}

type progressReader struct {
rc io.ReadCloser

count *int64 // number of bytes this reader has read, to support resetting on retry.
progress *progress
}

func (r *progressReader) Read(b []byte) (int, error) {
n, err := r.rc.Read(b)
if err != nil {
return n, err
}
atomic.AddInt64(r.count, int64(n))
// TODO: warn/debug log if sending takes too long, or if sending is blocked while context is canceled.
r.progress.complete(int64(n))
return n, nil
}

func (r *progressReader) Close() error { return r.rc.Close() }
103 changes: 35 additions & 68 deletions pkg/v1/remote/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"net/http"
"net/url"
"strings"
"sync/atomic"

"github.com/google/go-containerregistry/internal/redact"
"github.com/google/go-containerregistry/internal/retry"
Expand All @@ -49,20 +48,21 @@ func Write(ref name.Reference, img v1.Image, options ...Option) (rerr error) {
return err
}

var lastUpdate *v1.Update
var p *progress
if o.updates != nil {
lastUpdate = &v1.Update{}
lastUpdate.Total, err = countImage(img, o.allowNondistributableArtifacts)
p = &progress{updates: o.updates}
p.lastUpdate = &v1.Update{}
p.lastUpdate.Total, err = countImage(img, o.allowNondistributableArtifacts)
if err != nil {
return err
}
defer close(o.updates)
defer func() { _ = sendError(o.updates, rerr) }()
defer func() { _ = p.err(rerr) }()
}
return writeImage(o.context, ref, img, o, lastUpdate)
return writeImage(o.context, ref, img, o, p)
}

func writeImage(ctx context.Context, ref name.Reference, img v1.Image, o *options, lastUpdate *v1.Update) error {
func writeImage(ctx context.Context, ref name.Reference, img v1.Image, o *options, progress *progress) error {
ls, err := img.Layers()
if err != nil {
return err
Expand All @@ -73,13 +73,12 @@ func writeImage(ctx context.Context, ref name.Reference, img v1.Image, o *option
return err
}
w := writer{
repo: ref.Context(),
client: &http.Client{Transport: tr},
context: ctx,
updates: o.updates,
lastUpdate: lastUpdate,
backoff: o.retryBackoff,
predicate: o.retryPredicate,
repo: ref.Context(),
client: &http.Client{Transport: tr},
context: ctx,
progress: progress,
backoff: o.retryBackoff,
predicate: o.retryPredicate,
}

// Upload individual blobs and collect any errors.
Expand Down Expand Up @@ -174,17 +173,9 @@ type writer struct {
client *http.Client
context context.Context

updates chan<- v1.Update
lastUpdate *v1.Update
backoff Backoff
predicate retry.Predicate
}

func sendError(ch chan<- v1.Update, err error) error {
if err != nil && ch != nil {
ch <- v1.Update{Error: err}
}
return err
progress *progress
backoff Backoff
predicate retry.Predicate
}

// url returns a url.Url for the specified path in the context of this remote image reference.
Expand Down Expand Up @@ -310,30 +301,6 @@ func (w *writer) initiateUpload(from, mount, origin string) (location string, mo
}
}

type progressReader struct {
rc io.ReadCloser

count *int64 // number of bytes this reader has read, to support resetting on retry.
updates chan<- v1.Update
lastUpdate *v1.Update
}

func (r *progressReader) Read(b []byte) (int, error) {
n, err := r.rc.Read(b)
if err != nil {
return n, err
}
atomic.AddInt64(r.count, int64(n))
// TODO: warn/debug log if sending takes too long, or if sending is blocked while context is cancelled.
r.updates <- v1.Update{
Total: r.lastUpdate.Total,
Complete: atomic.AddInt64(&r.lastUpdate.Complete, int64(n)),
}
return n, nil
}

func (r *progressReader) Close() error { return r.rc.Close() }

// streamBlob streams the contents of the blob to the specified location.
// On failure, this will return an error. On success, this will return the location
// header indicating how to commit the streamed blob.
Expand All @@ -350,19 +317,18 @@ func (w *writer) streamBlob(ctx context.Context, layer v1.Layer, streamLocation
}

getBody := layer.Compressed
if w.updates != nil {
if w.progress != nil {
var count int64
blob = &progressReader{rc: blob, updates: w.updates, lastUpdate: w.lastUpdate, count: &count}
blob = &progressReader{rc: blob, progress: w.progress, count: &count}
getBody = func() (io.ReadCloser, error) {
blob, err := layer.Compressed()
if err != nil {
return nil, err
}
return &progressReader{rc: blob, updates: w.updates, lastUpdate: w.lastUpdate, count: &count}, nil
return &progressReader{rc: blob, progress: w.progress, count: &count}, nil
}
reset = func() {
atomic.AddInt64(&w.lastUpdate.Complete, -count)
w.updates <- *w.lastUpdate
w.progress.complete(-count)
}
}

Expand Down Expand Up @@ -419,13 +385,10 @@ func (w *writer) commitBlob(location, digest string) error {

// incrProgress increments and sends a progress update, if WithProgress is used.
func (w *writer) incrProgress(written int64) {
if w.updates == nil {
if w.progress == nil {
return
}
w.updates <- v1.Update{
Total: w.lastUpdate.Total,
Complete: atomic.AddInt64(&w.lastUpdate.Complete, written),
}
w.progress.complete(written)
}

// uploadOne performs a complete upload of a single layer.
Expand Down Expand Up @@ -546,7 +509,7 @@ func (w *writer) writeIndex(ctx context.Context, ref name.Reference, ii v1.Image
if err != nil {
return err
}
if err := writeImage(ctx, ref, img, o, w.lastUpdate); err != nil {
if err := writeImage(ctx, ref, img, o, w.progress); err != nil {
return err
}
default:
Expand Down Expand Up @@ -689,19 +652,21 @@ func WriteIndex(ref name.Reference, ii v1.ImageIndex, options ...Option) (rerr e
repo: ref.Context(),
client: &http.Client{Transport: tr},
context: o.context,
updates: o.updates,
backoff: o.retryBackoff,
predicate: o.retryPredicate,
}

if o.updates != nil {
w.lastUpdate = &v1.Update{}
w.lastUpdate.Total, err = countIndex(ii, o.allowNondistributableArtifacts)
w.progress = &progress{updates: o.updates}
w.progress.lastUpdate = &v1.Update{}

defer close(o.updates)
defer func() { w.progress.err(rerr) }()

w.progress.lastUpdate.Total, err = countIndex(ii, o.allowNondistributableArtifacts)
if err != nil {
return err
}
defer close(o.updates)
defer func() { sendError(o.updates, rerr) }()
}

return w.writeIndex(o.context, ref, ii, options...)
Expand Down Expand Up @@ -830,14 +795,16 @@ func WriteLayer(repo name.Repository, layer v1.Layer, options ...Option) (rerr e
repo: repo,
client: &http.Client{Transport: tr},
context: o.context,
updates: o.updates,
backoff: o.retryBackoff,
predicate: o.retryPredicate,
}

if o.updates != nil {
w.progress = &progress{updates: o.updates}
w.progress.lastUpdate = &v1.Update{}

defer close(o.updates)
defer func() { sendError(o.updates, rerr) }()
defer func() { w.progress.err(rerr) }()

// TODO: support streaming layers which update the total count as they write.
if _, ok := layer.(*stream.Layer); ok {
Expand All @@ -847,7 +814,7 @@ func WriteLayer(repo name.Repository, layer v1.Layer, options ...Option) (rerr e
if err != nil {
return err
}
w.lastUpdate = &v1.Update{Total: size}
w.progress.total(size)
}
return w.uploadOne(o.context, layer)
}
Expand Down

0 comments on commit 86f0c4a

Please sign in to comment.