Skip to content

Commit

Permalink
ociregistry/ocifilter: add AccessChecker
Browse files Browse the repository at this point in the history
This provides a more general mechanism than `Select`,
allowing a client full control over the error that's returned.

Signed-off-by: Roger Peppe <rogpeppe@gmail.com>
Change-Id: Ia783c0d8efa139deb9ddebbda57b8decf28d6afa
Dispatch-Trailer: {"type":"trybot","CL":1191593,"patchset":2,"ref":"refs/changes/93/1191593/2","targetBranch":"main"}
  • Loading branch information
rogpeppe authored and porcuepine committed Apr 4, 2024
1 parent a45a207 commit 8f40bcf
Show file tree
Hide file tree
Showing 3 changed files with 325 additions and 69 deletions.
174 changes: 110 additions & 64 deletions ociregistry/ocifilter/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,161 +21,207 @@ import (
"cuelabs.dev/go/oci/ociregistry"
)

// Select returns a wrapper for r that provides only
// repositories for which allow returns true.
// AccessKind
type AccessKind int

const (
// [ociregistry.Reader] methods.
AccessRead AccessKind = iota

// [ociregistry.Writer] methods.
AccessWrite

// [ociregistry.Deleter] methods.
AccessDelete

// [ociregistry.Lister] methods.
AccessList
)

// AccessChecker returns a wrapper for r that invokes check
// to check access before calling an underlying method. Only if check succeeds will
// the underlying method be called.
//
// Requests for disallowed repositories will return ErrNameUnknown
// errors on read and ErrDenied on write.
func Select(r ociregistry.Interface, allow func(repoName string) bool) ociregistry.Interface {
return &selectRegistry{
allow: allow,
// The check function is invoked with the name of the repository being
// accessed (or "*" for Repositories), and the kind of access required.
// For some methods (e.g. Mount), check might be invoked more than
// once for a given repository.
//
// When invoking the Repositories method, check is invoked for each repository in
// the iteration - the repository will be omitted if check returns an error.
func AccessChecker(r ociregistry.Interface, check func(repoName string, access AccessKind) error) ociregistry.Interface {
return &accessCheckerRegistry{
check: check,
r: r,
}
}

type selectRegistry struct {
type accessCheckerRegistry struct {
// Embed Funcs rather than the interface directly so that
// if new methods are added and selectRegistry isn't updated,
// we fall back to returning an error rather than passing through the method.
*ociregistry.Funcs
allow func(repoName string) bool
check func(repoName string, kind AccessKind) error
r ociregistry.Interface
}

func (r *selectRegistry) GetBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) {
if !r.allow(repo) {
return nil, ociregistry.ErrNameUnknown
// Select returns a wrapper for r that provides only
// repositories for which allow returns true.
//
// Requests for disallowed repositories will return ErrNameUnknown
// errors on read and ErrDenied on write.
func Select(r ociregistry.Interface, allow func(repoName string) bool) ociregistry.Interface {
return AccessChecker(r, func(repoName string, access AccessKind) error {
if allow(repoName) {
return nil
}
if access == AccessWrite {
return ociregistry.ErrDenied
}
if access == AccessList && repoName == "*" {
return nil
}
return ociregistry.ErrNameUnknown
})
}

func (r *accessCheckerRegistry) GetBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) {
if err := r.check(repo, AccessRead); err != nil {
return nil, err
}
return r.r.GetBlob(ctx, repo, digest)
}

func (r *selectRegistry) GetBlobRange(ctx context.Context, repo string, digest ociregistry.Digest, offset0, offset1 int64) (ociregistry.BlobReader, error) {
if !r.allow(repo) {
return nil, ociregistry.ErrNameUnknown
func (r *accessCheckerRegistry) GetBlobRange(ctx context.Context, repo string, digest ociregistry.Digest, offset0, offset1 int64) (ociregistry.BlobReader, error) {
if err := r.check(repo, AccessRead); err != nil {
return nil, err
}
return r.r.GetBlobRange(ctx, repo, digest, offset0, offset1)
}

func (r *selectRegistry) GetManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) {
if !r.allow(repo) {
return nil, ociregistry.ErrNameUnknown
func (r *accessCheckerRegistry) GetManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.BlobReader, error) {
if err := r.check(repo, AccessRead); err != nil {
return nil, err
}
return r.r.GetManifest(ctx, repo, digest)
}

func (r *selectRegistry) GetTag(ctx context.Context, repo string, tagName string) (ociregistry.BlobReader, error) {
if !r.allow(repo) {
return nil, ociregistry.ErrNameUnknown
func (r *accessCheckerRegistry) GetTag(ctx context.Context, repo string, tagName string) (ociregistry.BlobReader, error) {
if err := r.check(repo, AccessRead); err != nil {
return nil, err
}
return r.r.GetTag(ctx, repo, tagName)
}

func (r *selectRegistry) ResolveBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) {
if !r.allow(repo) {
return ociregistry.Descriptor{}, ociregistry.ErrNameUnknown
func (r *accessCheckerRegistry) ResolveBlob(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) {
if err := r.check(repo, AccessRead); err != nil {
return ociregistry.Descriptor{}, err
}
return r.r.ResolveBlob(ctx, repo, digest)
}

func (r *selectRegistry) ResolveManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) {
if !r.allow(repo) {
return ociregistry.Descriptor{}, ociregistry.ErrNameUnknown
func (r *accessCheckerRegistry) ResolveManifest(ctx context.Context, repo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) {
if err := r.check(repo, AccessRead); err != nil {
return ociregistry.Descriptor{}, err
}
return r.r.ResolveManifest(ctx, repo, digest)
}

func (r *selectRegistry) ResolveTag(ctx context.Context, repo string, tagName string) (ociregistry.Descriptor, error) {
if !r.allow(repo) {
return ociregistry.Descriptor{}, ociregistry.ErrNameUnknown
func (r *accessCheckerRegistry) ResolveTag(ctx context.Context, repo string, tagName string) (ociregistry.Descriptor, error) {
if err := r.check(repo, AccessRead); err != nil {
return ociregistry.Descriptor{}, err
}
return r.r.ResolveTag(ctx, repo, tagName)
}

func (r *selectRegistry) PushBlob(ctx context.Context, repo string, desc ociregistry.Descriptor, rd io.Reader) (ociregistry.Descriptor, error) {
if !r.allow(repo) {
return ociregistry.Descriptor{}, ociregistry.ErrDenied
func (r *accessCheckerRegistry) PushBlob(ctx context.Context, repo string, desc ociregistry.Descriptor, rd io.Reader) (ociregistry.Descriptor, error) {
if err := r.check(repo, AccessWrite); err != nil {
return ociregistry.Descriptor{}, err
}
return r.r.PushBlob(ctx, repo, desc, rd)
}

func (r *selectRegistry) PushBlobChunked(ctx context.Context, repo string, chunkSize int) (ociregistry.BlobWriter, error) {
if !r.allow(repo) {
return nil, ociregistry.ErrDenied
func (r *accessCheckerRegistry) PushBlobChunked(ctx context.Context, repo string, chunkSize int) (ociregistry.BlobWriter, error) {
if err := r.check(repo, AccessWrite); err != nil {
return nil, err
}
return r.r.PushBlobChunked(ctx, repo, chunkSize)
}

func (r *selectRegistry) PushBlobChunkedResume(ctx context.Context, repo, id string, offset int64, chunkSize int) (ociregistry.BlobWriter, error) {
if !r.allow(repo) {
return nil, ociregistry.ErrDenied
func (r *accessCheckerRegistry) PushBlobChunkedResume(ctx context.Context, repo, id string, offset int64, chunkSize int) (ociregistry.BlobWriter, error) {
if err := r.check(repo, AccessWrite); err != nil {
return nil, err
}
return r.r.PushBlobChunkedResume(ctx, repo, id, offset, chunkSize)
}

func (r *selectRegistry) MountBlob(ctx context.Context, fromRepo, toRepo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) {
if !r.allow(toRepo) {
return ociregistry.Descriptor{}, ociregistry.ErrDenied
func (r *accessCheckerRegistry) MountBlob(ctx context.Context, fromRepo, toRepo string, digest ociregistry.Digest) (ociregistry.Descriptor, error) {
if err := r.check(fromRepo, AccessRead); err != nil {
return ociregistry.Descriptor{}, err
}
if !r.allow(fromRepo) {
return ociregistry.Descriptor{}, ociregistry.ErrNameUnknown
if err := r.check(toRepo, AccessWrite); err != nil {
return ociregistry.Descriptor{}, err
}
return r.r.MountBlob(ctx, fromRepo, toRepo, digest)
}

func (r *selectRegistry) PushManifest(ctx context.Context, repo string, tag string, contents []byte, mediaType string) (ociregistry.Descriptor, error) {
if !r.allow(repo) {
return ociregistry.Descriptor{}, ociregistry.ErrDenied
func (r *accessCheckerRegistry) PushManifest(ctx context.Context, repo string, tag string, contents []byte, mediaType string) (ociregistry.Descriptor, error) {
if err := r.check(repo, AccessWrite); err != nil {
return ociregistry.Descriptor{}, err
}
return r.r.PushManifest(ctx, repo, tag, contents, mediaType)
}

func (r *selectRegistry) DeleteBlob(ctx context.Context, repo string, digest ociregistry.Digest) error {
if !r.allow(repo) {
return ociregistry.ErrNameUnknown
func (r *accessCheckerRegistry) DeleteBlob(ctx context.Context, repo string, digest ociregistry.Digest) error {
if err := r.check(repo, AccessDelete); err != nil {
return err
}
return r.r.DeleteBlob(ctx, repo, digest)
}

func (r *selectRegistry) DeleteManifest(ctx context.Context, repo string, digest ociregistry.Digest) error {
if !r.allow(repo) {
return ociregistry.ErrNameUnknown
func (r *accessCheckerRegistry) DeleteManifest(ctx context.Context, repo string, digest ociregistry.Digest) error {
if err := r.check(repo, AccessDelete); err != nil {
return err
}
return r.r.DeleteManifest(ctx, repo, digest)
}

func (r *selectRegistry) DeleteTag(ctx context.Context, repo string, name string) error {
if !r.allow(repo) {
return ociregistry.ErrNameUnknown
func (r *accessCheckerRegistry) DeleteTag(ctx context.Context, repo string, name string) error {
if err := r.check(repo, AccessDelete); err != nil {
return err
}
return r.r.DeleteTag(ctx, repo, name)
}

func (r *selectRegistry) Repositories(ctx context.Context, startAfter string) ociregistry.Seq[string] {
func (r *accessCheckerRegistry) Repositories(ctx context.Context, startAfter string) ociregistry.Seq[string] {
if err := r.check("*", AccessList); err != nil {
return ociregistry.ErrorSeq[string](err)
}
return func(yield func(string, error) bool) {
// TODO(go1.23): for name, err := range r.r.Repositories(ctx)
r.r.Repositories(ctx, startAfter)(func(repo string, err error) bool {
if err != nil {
yield("", err)
return false
}
if !r.allow(repo) {
if r.check(repo, AccessRead) != nil {
return true
}
return yield(repo, nil)
})
}
}

func (r *selectRegistry) Tags(ctx context.Context, repo, startAfter string) ociregistry.Seq[string] {
if !r.allow(repo) {
return ociregistry.ErrorSeq[string](ociregistry.ErrNameUnknown)
func (r *accessCheckerRegistry) Tags(ctx context.Context, repo, startAfter string) ociregistry.Seq[string] {
if err := r.check(repo, AccessList); err != nil {
return ociregistry.ErrorSeq[string](err)
}
return r.r.Tags(ctx, repo, startAfter)
}

func (r *selectRegistry) Referrers(ctx context.Context, repo string, digest ociregistry.Digest, artifactType string) ociregistry.Seq[ociregistry.Descriptor] {
if !r.allow(repo) {
return ociregistry.ErrorSeq[ociregistry.Descriptor](ociregistry.ErrNameUnknown)
func (r *accessCheckerRegistry) Referrers(ctx context.Context, repo string, digest ociregistry.Digest, artifactType string) ociregistry.Seq[ociregistry.Descriptor] {
if err := r.check(repo, AccessList); err != nil {
return ociregistry.ErrorSeq[ociregistry.Descriptor](err)
}
return r.r.Referrers(ctx, repo, digest, artifactType)
}
Loading

0 comments on commit 8f40bcf

Please sign in to comment.