Reverted newDigestingReader to validate incoming blob during decryption
Signed-off-by: Brandon Lum <lumjjb@gmail.com>
lumjjb committed Nov 24, 2019
1 parent 372076d commit 53d9e53
Showing 3 changed files with 27 additions and 36 deletions.
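
In effect, the commit restores unconditional digest validation and moves the digesting reader in front of decryption: ocicrypt.DecryptLayer now reads through the digesting reader, so the incoming (still encrypted) blob is checked against srcInfo.Digest as it is consumed. The following is a rough sketch of the resulting reader chain in copyBlobFromStream, assembled from the diff below, with error handling abbreviated and the decrypted-layer digest returned by DecryptLayer dropped for brevity; it is a sketch, not a verbatim excerpt of the new code.

	// Wrap the raw source stream so its digest is computed and checked as it is read.
	digestingReader, err := newDigestingReader(srcStream, srcInfo.Digest)
	if err != nil {
		return types.BlobInfo{}, errors.Wrapf(err, "Error preparing to verify blob %s", srcInfo.Digest)
	}

	// Decryption, if required, now consumes the digesting reader, so the blob is
	// validated in its incoming (encrypted) form.
	var destStream io.Reader = digestingReader
	if isOciEncrypted(srcInfo.MediaType) && c.ociDecryptConfig != nil {
		newDesc := imgspecv1.Descriptor{Annotations: srcInfo.Annotations}
		destStream, _, err = ocicrypt.DecryptLayer(c.ociDecryptConfig, destStream, newDesc, false)
		if err != nil {
			return types.BlobInfo{}, errors.Wrapf(err, "Error decrypting layer %s", srcInfo.Digest)
		}
	}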
53 changes: 22 additions & 31 deletions copy/copy.go
@@ -39,7 +39,6 @@ type digestingReader struct {
expectedDigest digest.Digest
validationFailed bool
validationSucceeded bool
validateDigests bool
}

var (
@@ -57,40 +56,36 @@ var compressionBufferSize = 1048576
// newDigestingReader returns an io.Reader implementation with contents of source, which will eventually return a non-EOF error
// or set validationSucceeded/validationFailed to true if the source stream does/does not match expectedDigest.
// (neither is set if EOF is never reached).
func newDigestingReader(source io.Reader, expectedDigest digest.Digest, validateDigests bool) (*digestingReader, error) {
func newDigestingReader(source io.Reader, expectedDigest digest.Digest) (*digestingReader, error) {
var digester digest.Digester
if validateDigests {
if err := expectedDigest.Validate(); err != nil {
return nil, errors.Errorf("Invalid digest specification %s", expectedDigest)
}
digestAlgorithm := expectedDigest.Algorithm()
if !digestAlgorithm.Available() {
return nil, errors.Errorf("Invalid digest specification %s: unsupported digest algorithm %s", expectedDigest, digestAlgorithm)
}
digester = digestAlgorithm.Digester()
if err := expectedDigest.Validate(); err != nil {
return nil, errors.Errorf("Invalid digest specification %s", expectedDigest)
}
digestAlgorithm := expectedDigest.Algorithm()
if !digestAlgorithm.Available() {
return nil, errors.Errorf("Invalid digest specification %s: unsupported digest algorithm %s", expectedDigest, digestAlgorithm)
}
digester = digestAlgorithm.Digester()

return &digestingReader{
source: source,
digester: digester,
expectedDigest: expectedDigest,
validationFailed: false,
validateDigests: validateDigests,
}, nil
}

func (d *digestingReader) Read(p []byte) (int, error) {
n, err := d.source.Read(p)
if d.validateDigests {
if n > 0 {
if n2, err := d.digester.Hash().Write(p[:n]); n2 != n || err != nil {
// Coverage: This should not happen, the hash.Hash interface requires
// d.digest.Write to never return an error, and the io.Writer interface
// requires n2 == len(input) if no error is returned.
return 0, errors.Wrapf(err, "Error updating digest during verification: %d vs. %d", n2, n)
}
if n > 0 {
if n2, err := d.digester.Hash().Write(p[:n]); n2 != n || err != nil {
// Coverage: This should not happen, the hash.Hash interface requires
// d.digest.Write to never return an error, and the io.Writer interface
// requires n2 == len(input) if no error is returned.
return 0, errors.Wrapf(err, "Error updating digest during verification: %d vs. %d", n2, n)
}
}
if err == io.EOF && d.validateDigests {
if err == io.EOF {
actualDigest := d.digester.Digest()
if actualDigest != d.expectedDigest {
d.validationFailed = true
@@ -1154,16 +1149,20 @@ func (c *copier) copyBlobFromStream(ctx context.Context, srcStream io.Reader, sr
// Note that for this check we don't use the stronger "validationSucceeded" indicator, because
// dest.PutBlob may detect that the layer already exists, in which case we don't
// read stream to the end, and validation does not happen.
digestingReader, err := newDigestingReader(srcStream, srcInfo.Digest)
if err != nil {
return types.BlobInfo{}, errors.Wrapf(err, "Error preparing to verify blob %s", srcInfo.Digest)
}

var destStream io.Reader = digestingReader
var decrypted bool
var err error
if isOciEncrypted(srcInfo.MediaType) && c.ociDecryptConfig != nil {
newDesc := imgspecv1.Descriptor{
Annotations: srcInfo.Annotations,
}

var d digest.Digest
srcStream, d, err = ocicrypt.DecryptLayer(c.ociDecryptConfig, srcStream, newDesc, false)
destStream, d, err = ocicrypt.DecryptLayer(c.ociDecryptConfig, destStream, newDesc, false)
if err != nil {
return types.BlobInfo{}, errors.Wrapf(err, "Error decrypting layer %s", srcInfo.Digest)
}
@@ -1178,14 +1177,6 @@ func (c *copier) copyBlobFromStream(ctx context.Context, srcStream io.Reader, sr
decrypted = true
}

validateDigest := srcInfo.Digest != ""

digestingReader, err := newDigestingReader(srcStream, srcInfo.Digest, validateDigest)
if err != nil {
return types.BlobInfo{}, errors.Wrapf(err, "Error preparing to verify blob %s", srcInfo.Digest)
}
var destStream io.Reader = digestingReader

// === Detect compression of the input stream.
// This requires us to “peek ahead” into the stream to read the initial part, which requires us to chain through another io.Reader returned by DetectCompression.
compressionFormat, decompressor, destStream, err := compression.DetectCompressionFormat(destStream) // We could skip this in some cases, but let's keep the code path uniform
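
For orientation, here is a minimal usage sketch of the restored two-argument newDigestingReader, mirroring the updated tests below. The wrapper function name and the bytes/io/ioutil/go-digest imports are assumed for illustration and are not part of the diff.

	func exampleDigestingReader() { // hypothetical caller, for illustration only
		blob := []byte("example blob contents")
		reader, err := newDigestingReader(bytes.NewReader(blob), digest.FromBytes(blob))
		if err != nil {
			panic(err) // the expected digest was malformed or uses an unavailable algorithm
		}
		if _, err := io.Copy(ioutil.Discard, reader); err != nil {
			panic(err) // the stream did not match the expected digest; reader.validationFailed is set
		}
		// reader.validationSucceeded is true only after the stream was read to EOF and matched.
	}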
8 changes: 4 additions & 4 deletions copy/copy_test.go
@@ -25,7 +25,7 @@ func TestNewDigestingReader(t *testing.T) {
"sha256:0", // Invalid hex value
"sha256:01", // Invalid length of hex value
} {
_, err := newDigestingReader(source, input, true)
_, err := newDigestingReader(source, input)
assert.Error(t, err, input.String())
}
}
@@ -42,7 +42,7 @@ func TestDigestingReaderRead(t *testing.T) {
// Valid input
for _, c := range cases {
source := bytes.NewReader(c.input)
reader, err := newDigestingReader(source, c.digest, true)
reader, err := newDigestingReader(source, c.digest)
require.NoError(t, err, c.digest.String())
dest := bytes.Buffer{}
n, err := io.Copy(&dest, reader)
@@ -55,7 +55,7 @@ func TestDigestingReaderRead(t *testing.T) {
// Modified input
for _, c := range cases {
source := bytes.NewReader(bytes.Join([][]byte{c.input, []byte("x")}, nil))
reader, err := newDigestingReader(source, c.digest, true)
reader, err := newDigestingReader(source, c.digest)
require.NoError(t, err, c.digest.String())
dest := bytes.Buffer{}
_, err = io.Copy(&dest, reader)
@@ -66,7 +66,7 @@ func TestDigestingReaderRead(t *testing.T) {
// Truncated input
for _, c := range cases {
source := bytes.NewReader(c.input)
reader, err := newDigestingReader(source, c.digest, true)
reader, err := newDigestingReader(source, c.digest)
require.NoError(t, err, c.digest.String())
if len(c.input) != 0 {
dest := bytes.Buffer{}
2 changes: 1 addition & 1 deletion copy/encrypt.go
@@ -6,7 +6,7 @@ import (
"github.com/containers/image/v5/types"
)

// isOciEncrypted returns if a mediatype is encrypted
// isOciEncrypted returns a bool indicating if a mediatype is encrypted
// This function will be moved to be part of OCI spec when adopted.
func isOciEncrypted(mediatype string) bool {
return strings.HasSuffix(mediatype, "+encrypted")
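
As an illustration (not part of the commit), the "+encrypted" suffix check above classifies media types like the following; the concrete media type strings here are assumed examples.

	isOciEncrypted("application/vnd.oci.image.layer.v1.tar+gzip+encrypted") // true
	isOciEncrypted("application/vnd.oci.image.layer.v1.tar+gzip")           // false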
