diff --git a/copy/copy.go b/copy/copy.go index 8751c7b048..74e84fd56e 100644 --- a/copy/copy.go +++ b/copy/copy.go @@ -39,7 +39,6 @@ type digestingReader struct { expectedDigest digest.Digest validationFailed bool validationSucceeded bool - validateDigests bool } var ( @@ -57,40 +56,36 @@ var compressionBufferSize = 1048576 // newDigestingReader returns an io.Reader implementation with contents of source, which will eventually return a non-EOF error // or set validationSucceeded/validationFailed to true if the source stream does/does not match expectedDigest. // (neither is set if EOF is never reached). -func newDigestingReader(source io.Reader, expectedDigest digest.Digest, validateDigests bool) (*digestingReader, error) { +func newDigestingReader(source io.Reader, expectedDigest digest.Digest) (*digestingReader, error) { var digester digest.Digester - if validateDigests { - if err := expectedDigest.Validate(); err != nil { - return nil, errors.Errorf("Invalid digest specification %s", expectedDigest) - } - digestAlgorithm := expectedDigest.Algorithm() - if !digestAlgorithm.Available() { - return nil, errors.Errorf("Invalid digest specification %s: unsupported digest algorithm %s", expectedDigest, digestAlgorithm) - } - digester = digestAlgorithm.Digester() + if err := expectedDigest.Validate(); err != nil { + return nil, errors.Errorf("Invalid digest specification %s", expectedDigest) } + digestAlgorithm := expectedDigest.Algorithm() + if !digestAlgorithm.Available() { + return nil, errors.Errorf("Invalid digest specification %s: unsupported digest algorithm %s", expectedDigest, digestAlgorithm) + } + digester = digestAlgorithm.Digester() + return &digestingReader{ source: source, digester: digester, expectedDigest: expectedDigest, validationFailed: false, - validateDigests: validateDigests, }, nil } func (d *digestingReader) Read(p []byte) (int, error) { n, err := d.source.Read(p) - if d.validateDigests { - if n > 0 { - if n2, err := d.digester.Hash().Write(p[:n]); n2 != n || err != nil { - // Coverage: This should not happen, the hash.Hash interface requires - // d.digest.Write to never return an error, and the io.Writer interface - // requires n2 == len(input) if no error is returned. - return 0, errors.Wrapf(err, "Error updating digest during verification: %d vs. %d", n2, n) - } + if n > 0 { + if n2, err := d.digester.Hash().Write(p[:n]); n2 != n || err != nil { + // Coverage: This should not happen, the hash.Hash interface requires + // d.digest.Write to never return an error, and the io.Writer interface + // requires n2 == len(input) if no error is returned. + return 0, errors.Wrapf(err, "Error updating digest during verification: %d vs. %d", n2, n) } } - if err == io.EOF && d.validateDigests { + if err == io.EOF { actualDigest := d.digester.Digest() if actualDigest != d.expectedDigest { d.validationFailed = true @@ -1154,16 +1149,20 @@ func (c *copier) copyBlobFromStream(ctx context.Context, srcStream io.Reader, sr // Note that for this check we don't use the stronger "validationSucceeded" indicator, because // dest.PutBlob may detect that the layer already exists, in which case we don't // read stream to the end, and validation does not happen. + digestingReader, err := newDigestingReader(srcStream, srcInfo.Digest) + if err != nil { + return types.BlobInfo{}, errors.Wrapf(err, "Error preparing to verify blob %s", srcInfo.Digest) + } + var destStream io.Reader = digestingReader var decrypted bool - var err error if isOciEncrypted(srcInfo.MediaType) && c.ociDecryptConfig != nil { newDesc := imgspecv1.Descriptor{ Annotations: srcInfo.Annotations, } var d digest.Digest - srcStream, d, err = ocicrypt.DecryptLayer(c.ociDecryptConfig, srcStream, newDesc, false) + destStream, d, err = ocicrypt.DecryptLayer(c.ociDecryptConfig, destStream, newDesc, false) if err != nil { return types.BlobInfo{}, errors.Wrapf(err, "Error decrypting layer %s", srcInfo.Digest) } @@ -1178,14 +1177,6 @@ func (c *copier) copyBlobFromStream(ctx context.Context, srcStream io.Reader, sr decrypted = true } - validateDigest := srcInfo.Digest != "" - - digestingReader, err := newDigestingReader(srcStream, srcInfo.Digest, validateDigest) - if err != nil { - return types.BlobInfo{}, errors.Wrapf(err, "Error preparing to verify blob %s", srcInfo.Digest) - } - var destStream io.Reader = digestingReader - // === Detect compression of the input stream. // This requires us to “peek ahead” into the stream to read the initial part, which requires us to chain through another io.Reader returned by DetectCompression. compressionFormat, decompressor, destStream, err := compression.DetectCompressionFormat(destStream) // We could skip this in some cases, but let's keep the code path uniform diff --git a/copy/copy_test.go b/copy/copy_test.go index a092337efe..e8efa6c837 100644 --- a/copy/copy_test.go +++ b/copy/copy_test.go @@ -25,7 +25,7 @@ func TestNewDigestingReader(t *testing.T) { "sha256:0", // Invalid hex value "sha256:01", // Invalid length of hex value } { - _, err := newDigestingReader(source, input, true) + _, err := newDigestingReader(source, input) assert.Error(t, err, input.String()) } } @@ -42,7 +42,7 @@ func TestDigestingReaderRead(t *testing.T) { // Valid input for _, c := range cases { source := bytes.NewReader(c.input) - reader, err := newDigestingReader(source, c.digest, true) + reader, err := newDigestingReader(source, c.digest) require.NoError(t, err, c.digest.String()) dest := bytes.Buffer{} n, err := io.Copy(&dest, reader) @@ -55,7 +55,7 @@ func TestDigestingReaderRead(t *testing.T) { // Modified input for _, c := range cases { source := bytes.NewReader(bytes.Join([][]byte{c.input, []byte("x")}, nil)) - reader, err := newDigestingReader(source, c.digest, true) + reader, err := newDigestingReader(source, c.digest) require.NoError(t, err, c.digest.String()) dest := bytes.Buffer{} _, err = io.Copy(&dest, reader) @@ -66,7 +66,7 @@ func TestDigestingReaderRead(t *testing.T) { // Truncated input for _, c := range cases { source := bytes.NewReader(c.input) - reader, err := newDigestingReader(source, c.digest, true) + reader, err := newDigestingReader(source, c.digest) require.NoError(t, err, c.digest.String()) if len(c.input) != 0 { dest := bytes.Buffer{} diff --git a/copy/encrypt.go b/copy/encrypt.go index d92a35d30d..a18d6f1518 100644 --- a/copy/encrypt.go +++ b/copy/encrypt.go @@ -6,7 +6,7 @@ import ( "github.com/containers/image/v5/types" ) -// isOciEncrypted returns if a mediatype is encrypted +// isOciEncrypted returns a bool indicating if a mediatype is encrypted // This function will be moved to be part of OCI spec when adopted. func isOciEncrypted(mediatype string) bool { return strings.HasSuffix(mediatype, "+encrypted")