Skip to content

Commit

Permalink
Merge pull request #6 from moby/ghsa-ambiguous-pull-by-digest
Browse files Browse the repository at this point in the history
[20.10] Validate digest in repo for pull by digest
  • Loading branch information
thaJeztah committed Oct 18, 2022
2 parents 3adff51 + 4b9902b commit 03df974
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 28 deletions.
107 changes: 101 additions & 6 deletions distribution/manifest.go
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"io/ioutil"
"strings"

"github.com/containerd/containerd/content"
"github.com/containerd/containerd/errdefs"
Expand All @@ -15,15 +16,22 @@ import (
"github.com/docker/distribution/manifest/manifestlist"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/manifest/schema2"
"github.com/docker/distribution/reference"
"github.com/docker/docker/registry"
digest "github.com/opencontainers/go-digest"
specs "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)

// labelDistributionSource describes the source blob comes from.
const labelDistributionSource = "containerd.io/distribution.source"

// This is used by manifestStore to pare down the requirements to implement a
// full distribution.ManifestService, since `Get` is all we use here.
type manifestGetter interface {
Get(ctx context.Context, dgst digest.Digest, options ...distribution.ManifestServiceOption) (distribution.Manifest, error)
Exists(ctx context.Context, dgst digest.Digest) (bool, error)
}

type manifestStore struct {
Expand All @@ -40,15 +48,98 @@ type ContentStore interface {
content.Provider
Info(ctx context.Context, dgst digest.Digest) (content.Info, error)
Abort(ctx context.Context, ref string) error
Update(ctx context.Context, info content.Info, fieldpaths ...string) (content.Info, error)
}

func makeDistributionSourceLabel(ref reference.Named) (string, string) {
domain := reference.Domain(ref)
if domain == "" {
domain = registry.DefaultNamespace
}
repo := reference.Path(ref)

return fmt.Sprintf("%s.%s", labelDistributionSource, domain), repo
}

func (m *manifestStore) getLocal(ctx context.Context, desc specs.Descriptor) (distribution.Manifest, error) {
// Taken from https://github.com/containerd/containerd/blob/e079e4a155c86f07bbd602fe6753ecacc78198c2/remotes/docker/handler.go#L84-L108
func appendDistributionSourceLabel(originLabel, repo string) string {
repos := []string{}
if originLabel != "" {
repos = strings.Split(originLabel, ",")
}
repos = append(repos, repo)

// use empty string to present duplicate items
for i := 1; i < len(repos); i++ {
tmp, j := repos[i], i-1
for ; j >= 0 && repos[j] >= tmp; j-- {
if repos[j] == tmp {
tmp = ""
}
repos[j+1] = repos[j]
}
repos[j+1] = tmp
}

i := 0
for ; i < len(repos) && repos[i] == ""; i++ {
}

return strings.Join(repos[i:], ",")
}

func hasDistributionSource(label, repo string) bool {
sources := strings.Split(label, ",")
for _, s := range sources {
if s == repo {
return true
}
}
return false
}

func (m *manifestStore) getLocal(ctx context.Context, desc specs.Descriptor, ref reference.Named) (distribution.Manifest, error) {
ra, err := m.local.ReaderAt(ctx, desc)
if err != nil {
return nil, errors.Wrap(err, "error getting content store reader")
}
defer ra.Close()

distKey, distRepo := makeDistributionSourceLabel(ref)
info, err := m.local.Info(ctx, desc.Digest)
if err != nil {
return nil, errors.Wrap(err, "error getting content info")
}

if _, ok := ref.(reference.Canonical); ok {
// Since this is specified by digest...
// We know we have the content locally, we need to check if we've seen this content at the specified repository before.
// If we have, we can just return the manifest from the local content store.
// If we haven't, we need to check the remote repository to see if it has the content, otherwise we can end up returning
// a manifest that has never even existed in the remote before.
if !hasDistributionSource(info.Labels[distKey], distRepo) {
logrus.WithField("ref", ref).Debug("found manifest but no mataching source repo is listed, checking with remote")
exists, err := m.remote.Exists(ctx, desc.Digest)
if err != nil {
return nil, errors.Wrap(err, "error checking if remote exists")
}

if !exists {
return nil, errors.Wrapf(errdefs.ErrNotFound, "manifest %v not found", desc.Digest)
}

}
}

// Update the distribution sources since we now know the content exists in the remote.
if info.Labels == nil {
info.Labels = map[string]string{}
}
info.Labels[distKey] = appendDistributionSourceLabel(info.Labels[distKey], distRepo)
if _, err := m.local.Update(ctx, info, "labels."+distKey); err != nil {
logrus.WithError(err).WithField("ref", ref).Warn("Could not update content distribution source")
}

r := io.NewSectionReader(ra, 0, ra.Size())
data, err := ioutil.ReadAll(r)
if err != nil {
Expand All @@ -59,6 +150,7 @@ func (m *manifestStore) getLocal(ctx context.Context, desc specs.Descriptor) (di
if err != nil {
return nil, errors.Wrap(err, "error unmarshaling manifest from content store")
}

return manifest, nil
}

Expand All @@ -76,7 +168,7 @@ func (m *manifestStore) getMediaType(ctx context.Context, desc specs.Descriptor)
return mt, nil
}

func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor) (distribution.Manifest, error) {
func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor, ref reference.Named) (distribution.Manifest, error) {
l := log.G(ctx)

if desc.MediaType == "" {
Expand Down Expand Up @@ -104,7 +196,7 @@ func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor) (distrib
if err != nil {
if errdefs.IsAlreadyExists(err) {
var manifest distribution.Manifest
if manifest, err = m.getLocal(ctx, desc); err == nil {
if manifest, err = m.getLocal(ctx, desc, ref); err == nil {
return manifest, nil
}
}
Expand All @@ -126,7 +218,7 @@ func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor) (distrib

if w != nil {
// if `w` is nil here, something happened with the content store, so don't bother trying to persist.
if err := m.Put(ctx, manifest, desc, w); err != nil {
if err := m.Put(ctx, manifest, desc, w, ref); err != nil {
if err := m.local.Abort(ctx, key); err != nil {
l.WithError(err).Warn("error aborting content ingest")
}
Expand All @@ -136,7 +228,7 @@ func (m *manifestStore) Get(ctx context.Context, desc specs.Descriptor) (distrib
return manifest, nil
}

func (m *manifestStore) Put(ctx context.Context, manifest distribution.Manifest, desc specs.Descriptor, w content.Writer) error {
func (m *manifestStore) Put(ctx context.Context, manifest distribution.Manifest, desc specs.Descriptor, w content.Writer, ref reference.Named) error {
mt, payload, err := manifest.Payload()
if err != nil {
return err
Expand All @@ -148,7 +240,10 @@ func (m *manifestStore) Put(ctx context.Context, manifest distribution.Manifest,
return errors.Wrap(err, "error writing manifest to content store")
}

if err := w.Commit(ctx, desc.Size, desc.Digest); err != nil {
distKey, distSource := makeDistributionSourceLabel(ref)
if err := w.Commit(ctx, desc.Size, desc.Digest, content.WithLabels(map[string]string{
distKey: distSource,
})); err != nil {
return errors.Wrap(err, "error committing manifest to content store")
}
return nil
Expand Down
62 changes: 42 additions & 20 deletions distribution/manifest_test.go
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/docker/distribution/manifest/ocischema"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/manifest/schema2"
"github.com/docker/distribution/reference"
"github.com/google/go-cmp/cmp/cmpopts"
digest "github.com/opencontainers/go-digest"
specs "github.com/opencontainers/image-spec/specs-go/v1"
Expand All @@ -40,6 +41,11 @@ func (m *mockManifestGetter) Get(ctx context.Context, dgst digest.Digest, option
return manifest, nil
}

func (m *mockManifestGetter) Exists(ctx context.Context, dgst digest.Digest) (bool, error) {
_, ok := m.manifests[dgst]
return ok, nil
}

type memoryLabelStore struct {
mu sync.Mutex
labels map[digest.Digest]map[string]string
Expand Down Expand Up @@ -77,7 +83,9 @@ func (s *memoryLabelStore) Update(dgst digest.Digest, update map[string]string)
for k, v := range update {
labels[k] = v
}

if s.labels == nil {
s.labels = map[digest.Digest]map[string]string{}
}
s.labels[dgst] = labels

return labels, nil
Expand Down Expand Up @@ -126,7 +134,7 @@ func TestManifestStore(t *testing.T) {
assert.NilError(t, err)
dgst := digest.Canonical.FromBytes(serialized)

setupTest := func(t *testing.T) (specs.Descriptor, *mockManifestGetter, *manifestStore, content.Store, func(*testing.T)) {
setupTest := func(t *testing.T) (reference.Named, specs.Descriptor, *mockManifestGetter, *manifestStore, content.Store, func(*testing.T)) {
root, err := ioutil.TempDir("", strings.Replace(t.Name(), "/", "_", -1))
assert.NilError(t, err)
defer func() {
Expand All @@ -142,7 +150,10 @@ func TestManifestStore(t *testing.T) {
store := &manifestStore{local: cs, remote: mg}
desc := specs.Descriptor{Digest: dgst, MediaType: specs.MediaTypeImageManifest, Size: int64(len(serialized))}

return desc, mg, store, cs, func(t *testing.T) {
ref, err := reference.Parse("foo/bar")
assert.NilError(t, err)

return ref.(reference.Named), desc, mg, store, cs, func(t *testing.T) {
assert.Check(t, os.RemoveAll(root))
}
}
Expand Down Expand Up @@ -183,22 +194,22 @@ func TestManifestStore(t *testing.T) {
}

t.Run("no remote or local", func(t *testing.T) {
desc, _, store, cs, teardown := setupTest(t)
ref, desc, _, store, cs, teardown := setupTest(t)
defer teardown(t)

_, err = store.Get(ctx, desc)
_, err = store.Get(ctx, desc, ref)
checkIngest(t, cs, desc)
// This error is what our digest getter returns when it doesn't know about the manifest
assert.Error(t, err, distribution.ErrManifestUnknown{Tag: dgst.String()}.Error())
})

t.Run("no local cache", func(t *testing.T) {
desc, mg, store, cs, teardown := setupTest(t)
ref, desc, mg, store, cs, teardown := setupTest(t)
defer teardown(t)

mg.manifests[desc.Digest] = m

m2, err := store.Get(ctx, desc)
m2, err := store.Get(ctx, desc, ref)
checkIngest(t, cs, desc)
assert.NilError(t, err)
assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
Expand All @@ -208,23 +219,34 @@ func TestManifestStore(t *testing.T) {
assert.NilError(t, err)
assert.Check(t, cmp.Equal(i.Digest, desc.Digest))

distKey, distSource := makeDistributionSourceLabel(ref)
assert.Check(t, hasDistributionSource(i.Labels[distKey], distSource))

// Now check again, this should not hit the remote
m2, err = store.Get(ctx, desc)
m2, err = store.Get(ctx, desc, ref)
checkIngest(t, cs, desc)
assert.NilError(t, err)
assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
assert.Check(t, cmp.Equal(mg.gets, 1))

t.Run("digested", func(t *testing.T) {
ref, err := reference.WithDigest(ref, desc.Digest)
assert.NilError(t, err)

_, err = store.Get(ctx, desc, ref)
assert.NilError(t, err)
})
})

t.Run("with local cache", func(t *testing.T) {
desc, mg, store, cs, teardown := setupTest(t)
ref, desc, mg, store, cs, teardown := setupTest(t)
defer teardown(t)

// first add the manifest to the coontent store
writeManifest(t, cs, desc)

// now do the get
m2, err := store.Get(ctx, desc)
m2, err := store.Get(ctx, desc, ref)
checkIngest(t, cs, desc)
assert.NilError(t, err)
assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
Expand All @@ -238,13 +260,13 @@ func TestManifestStore(t *testing.T) {
// This is for the case of pull by digest where we don't know the media type of the manifest until it's actually pulled.
t.Run("unknown media type", func(t *testing.T) {
t.Run("no cache", func(t *testing.T) {
desc, mg, store, cs, teardown := setupTest(t)
ref, desc, mg, store, cs, teardown := setupTest(t)
defer teardown(t)

mg.manifests[desc.Digest] = m
desc.MediaType = ""

m2, err := store.Get(ctx, desc)
m2, err := store.Get(ctx, desc, ref)
checkIngest(t, cs, desc)
assert.NilError(t, err)
assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
Expand All @@ -253,27 +275,27 @@ func TestManifestStore(t *testing.T) {

t.Run("with cache", func(t *testing.T) {
t.Run("cached manifest has media type", func(t *testing.T) {
desc, mg, store, cs, teardown := setupTest(t)
ref, desc, mg, store, cs, teardown := setupTest(t)
defer teardown(t)

writeManifest(t, cs, desc)
desc.MediaType = ""

m2, err := store.Get(ctx, desc)
m2, err := store.Get(ctx, desc, ref)
checkIngest(t, cs, desc)
assert.NilError(t, err)
assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
assert.Check(t, cmp.Equal(mg.gets, 0))
})

t.Run("cached manifest has no media type", func(t *testing.T) {
desc, mg, store, cs, teardown := setupTest(t)
ref, desc, mg, store, cs, teardown := setupTest(t)
defer teardown(t)

desc.MediaType = ""
writeManifest(t, cs, desc)

m2, err := store.Get(ctx, desc)
m2, err := store.Get(ctx, desc, ref)
checkIngest(t, cs, desc)
assert.NilError(t, err)
assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
Expand All @@ -288,14 +310,14 @@ func TestManifestStore(t *testing.T) {
// Also makes sure the ingests are aborted.
t.Run("error persisting manifest", func(t *testing.T) {
t.Run("error on writer", func(t *testing.T) {
desc, mg, store, cs, teardown := setupTest(t)
ref, desc, mg, store, cs, teardown := setupTest(t)
defer teardown(t)
mg.manifests[desc.Digest] = m

csW := &testingContentStoreWrapper{ContentStore: store.local, errorOnWriter: errors.New("random error")}
store.local = csW

m2, err := store.Get(ctx, desc)
m2, err := store.Get(ctx, desc, ref)
checkIngest(t, cs, desc)
assert.NilError(t, err)
assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
Expand All @@ -307,14 +329,14 @@ func TestManifestStore(t *testing.T) {
})

t.Run("error on commit", func(t *testing.T) {
desc, mg, store, cs, teardown := setupTest(t)
ref, desc, mg, store, cs, teardown := setupTest(t)
defer teardown(t)
mg.manifests[desc.Digest] = m

csW := &testingContentStoreWrapper{ContentStore: store.local, errorOnCommit: errors.New("random error")}
store.local = csW

m2, err := store.Get(ctx, desc)
m2, err := store.Get(ctx, desc, ref)
checkIngest(t, cs, desc)
assert.NilError(t, err)
assert.Check(t, cmp.DeepEqual(m, m2, cmpopts.IgnoreUnexported(ocischema.DeserializedManifest{})))
Expand Down

0 comments on commit 03df974

Please sign in to comment.