From e0e7df7e2a1d77200a6216abad6b7f80c4dbe3b4 Mon Sep 17 00:00:00 2001 From: Oleksandr Krutko Date: Sun, 30 Jun 2024 18:57:05 +0300 Subject: [PATCH] introduce certitiface pool sctructure Signed-off-by: Oleksandr Krutko add a test for duplicate certs Signed-off-by: Oleksandr Krutko add comment for functions Signed-off-by: Oleksandr Krutko fix typos, add comments for PEM and Certificates pool functions Signed-off-by: Oleksandr Krutko improve functions naming in PEM logic Signed-off-by: Oleksandr Krutko improve PEM functions Signed-off-by: Oleksandr Krutko remove unused functions, add Option pattern Signed-off-by: Oleksandr Krutko align code with the project style Signed-off-by: Oleksandr Krutko --- pkg/bundle/source.go | 15 ++----- pkg/fspkg/package.go | 5 ++- pkg/util/cert_pool.go | 60 +++++++++++++++++++++++----- pkg/util/cert_pool_test.go | 4 +- pkg/util/pem.go | 80 ++++++++++++++++++-------------------- pkg/util/pem_test.go | 60 +++++++++++++++------------- test/env/data.go | 9 +++-- 7 files changed, 137 insertions(+), 96 deletions(-) diff --git a/pkg/bundle/source.go b/pkg/bundle/source.go index 3a432f5d..cd3a2abe 100644 --- a/pkg/bundle/source.go +++ b/pkg/bundle/source.go @@ -64,7 +64,7 @@ type bundleData struct { // is each bundle is concatenated together with a new line character. func (b *bundle) buildSourceBundle(ctx context.Context, bundle *trustapi.Bundle) (bundleData, error) { var resolvedBundle bundleData - var bundles []string + var certPool = util.NewCertPool(util.WithFilteredExpiredCerts(b.FilterExpiredCerts)) for _, source := range bundle.Spec.Sources { var ( @@ -99,27 +99,20 @@ func (b *bundle) buildSourceBundle(ctx context.Context, bundle *trustapi.Bundle) return bundleData{}, fmt.Errorf("failed to retrieve bundle from source: %w", err) } - opts := util.ValidateAndSanitizeOptions{FilterExpired: b.Options.FilterExpiredCerts} - sanitizedBundle, err := util.ValidateAndSanitizePEMBundleWithOptions([]byte(sourceData), opts) + err = util.ValidateAndSplitPEMBundle(certPool, []byte(sourceData)) if err != nil { return bundleData{}, fmt.Errorf("invalid PEM data in source: %w", err) } - bundles = append(bundles, string(sanitizedBundle)) } // NB: empty bundles are not valid so check and return an error if one somehow snuck through. - if len(bundles) == 0 { + if util.GetCertificatesQuantity(certPool) == 0 { return bundleData{}, fmt.Errorf("couldn't find any valid certificates in bundle") } - deduplicatedBundles, err := deduplicateBundles(bundles) - if err != nil { - return bundleData{}, err - } - - if err := resolvedBundle.populateData(deduplicatedBundles, bundle.Spec.Target); err != nil { + if err := resolvedBundle.populateData(util.AsPEMBundleStrings(certPool), bundle.Spec.Target); err != nil { return bundleData{}, err } diff --git a/pkg/fspkg/package.go b/pkg/fspkg/package.go index 253998e7..f608da19 100644 --- a/pkg/fspkg/package.go +++ b/pkg/fspkg/package.go @@ -64,7 +64,10 @@ func (p *Package) Clone() *Package { func (p *Package) Validate() error { // Ignore the sanitized bundle here and preserve the bundle as-is. // We'll sanitize later, when building a bundle on a reconcile. - _, err := util.ValidateAndSanitizePEMBundle([]byte(p.Bundle)) + + var certPool = util.NewCertPool(util.WithFilteredExpiredCerts(false)) + + err := util.ValidateAndSplitPEMBundle(certPool, []byte(p.Bundle)) if err != nil { return fmt.Errorf("package bundle failed validation: %w", err) } diff --git a/pkg/util/cert_pool.go b/pkg/util/cert_pool.go index 5d3394d7..e268c23a 100644 --- a/pkg/util/cert_pool.go +++ b/pkg/util/cert_pool.go @@ -17,6 +17,7 @@ limitations under the License. package util import ( + "crypto/sha256" "crypto/x509" "encoding/pem" "fmt" @@ -24,21 +25,36 @@ import ( ) // CertPool is a set of certificates. -type certPool struct { - certificates []*x509.Certificate - filterExpired bool +type CertPool struct { + certificatesHashes map[[32]byte]struct{} + certificates []*x509.Certificate + filterExpired bool +} + +type Option func(*CertPool) + +func WithFilteredExpiredCerts(filterExpired bool) Option { + return func(cp *CertPool) { + cp.filterExpired = filterExpired + } } // newCertPool returns a new, empty CertPool. -func newCertPool(filterExpired bool) *certPool { - return &certPool{ - certificates: make([]*x509.Certificate, 0), - filterExpired: filterExpired, +func NewCertPool(options ...Option) *CertPool { + var certPool = &CertPool{ + certificates: make([]*x509.Certificate, 0), + certificatesHashes: make(map[[32]byte]struct{}), } + + for _, option := range options { + option(certPool) + } + + return certPool } // Append certificate to a pool -func (cp *certPool) appendCertFromPEM(pemData []byte) error { +func (cp *CertPool) appendCertFromPEM(pemData []byte) error { if pemData == nil { return fmt.Errorf("certificate data can't be nil") } @@ -75,6 +91,10 @@ func (cp *certPool) appendCertFromPEM(pemData []byte) error { continue } + if cp.isDuplicate(certificate) { + continue + } + cp.certificates = append(cp.certificates, certificate) } @@ -82,7 +102,7 @@ func (cp *certPool) appendCertFromPEM(pemData []byte) error { } // Get PEM certificates from pool -func (cp *certPool) getCertsPEM() [][]byte { +func (cp *CertPool) getCertsPEM() [][]byte { var certsData [][]byte = make([][]byte, len(cp.certificates)) for i, cert := range cp.certificates { @@ -91,3 +111,25 @@ func (cp *certPool) getCertsPEM() [][]byte { return certsData } + +// Get certificates quantity in the certificates pool +func (cp *CertPool) size() int { + return len(cp.certificates) +} + +// Check deplicates of certificate in the certificates pool +func (cp *CertPool) isDuplicate(cert *x509.Certificate) bool { + hash := sha256.Sum256(cert.Raw) + // check existence of the hash + if _, ok := cp.certificatesHashes[hash]; !ok { + cp.certificatesHashes[hash] = struct{}{} + return false + } + + return true +} + +// Get the full list of x509 Certificates from the certificates pool +func (cp *CertPool) getCertsList() []*x509.Certificate { + return cp.certificates +} diff --git a/pkg/util/cert_pool_test.go b/pkg/util/cert_pool_test.go index 21e376a1..c77eb3aa 100644 --- a/pkg/util/cert_pool_test.go +++ b/pkg/util/cert_pool_test.go @@ -23,7 +23,7 @@ import ( ) func TestNewCertPool(t *testing.T) { - certPool := newCertPool(false) + certPool := NewCertPool(WithFilteredExpiredCerts(false)) if certPool == nil { t.Fatal("pool is nil") @@ -76,7 +76,7 @@ func TestAppendCertFromPEM(t *testing.T) { // populate certificates bundle for _, crt := range certificateList { - certPool := newCertPool(false) + certPool := NewCertPool(WithFilteredExpiredCerts(false)) if err := certPool.appendCertFromPEM([]byte(crt.certificate)); err != nil { t.Fatalf("error adding PEM certificate into pool %s", err) diff --git a/pkg/util/pem.go b/pkg/util/pem.go index 4a0fde0a..95558489 100644 --- a/pkg/util/pem.go +++ b/pkg/util/pem.go @@ -21,6 +21,7 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "strings" ) // ValidateAndSanitizePEMBundle strictly validates a given input PEM bundle to confirm it contains @@ -42,56 +43,19 @@ import ( // contain (accidental) private information. They're also non-standard according to // https://www.rfc-editor.org/rfc/rfc7468 -type ValidateAndSanitizeOptions struct { - FilterExpired bool // If true, expired certificates will be filtered out -} - -// ValidateAndSanitizePEMBundle keeps the original function signature for backward compatibility -func ValidateAndSanitizePEMBundle(data []byte) ([]byte, error) { - opts := ValidateAndSanitizeOptions{ - FilterExpired: false, - } - return ValidateAndSanitizePEMBundleWithOptions(data, opts) -} - -// ValidateAndSplitPEMBundle keeps the original function signature for backward compatibility -func ValidateAndSplitPEMBundle(data []byte) ([][]byte, error) { - opts := ValidateAndSanitizeOptions{ - FilterExpired: false, - } - return ValidateAndSplitPEMBundleWithOptions(data, opts) -} - // See also https://github.com/golang/go/blob/5d5ed57b134b7a02259ff070864f753c9e601a18/src/crypto/x509/cert_pool.go#L201-L239 // An option to enable filtering of expired certificates is available. -func ValidateAndSanitizePEMBundleWithOptions(data []byte, opts ValidateAndSanitizeOptions) ([]byte, error) { - certificates, err := ValidateAndSplitPEMBundleWithOptions(data, opts) +func ValidateAndSplitPEMBundle(certPool *CertPool, data []byte) error { + err := certPool.appendCertFromPEM(data) if err != nil { - return nil, err - } - - if len(certificates) == 0 { - return nil, fmt.Errorf("bundle contains no PEM certificates") + return err } - return bytes.TrimSpace(bytes.Join(certificates, nil)), nil -} - -// ValidateAndSplitPEMBundleWithOptions takes a PEM bundle as input, validates it and -// returns the list of certificates as a slice, allowing them to be iterated over. -// This process involves performs deduplication of certificates to ensure -// no duplicated certificates in the bundle. -// For details of the validation performed, see the comment for ValidateAndSanitizePEMBundle -// An option to enable filtering of expired certificates is available. -func ValidateAndSplitPEMBundleWithOptions(data []byte, opts ValidateAndSanitizeOptions) ([][]byte, error) { - var certPool *certPool = newCertPool(opts.FilterExpired) // put PEM encoded certificate into a pool - - err := certPool.appendCertFromPEM(data) - if err != nil { - return nil, fmt.Errorf("invalid PEM block in bundle; invalid PEM certificate: %w", err) + if certPool.size() == 0 { + return fmt.Errorf("bundle contains no PEM certificates") } - return certPool.getCertsPEM(), nil + return nil } // DecodeX509CertificateChainBytes will decode a PEM encoded x509 Certificate chain. @@ -121,3 +85,33 @@ func DecodeX509CertificateChainBytes(certBytes []byte) ([]*x509.Certificate, err return certs, nil } + +// Get the split bundle of all certificates in the certificates pool as representation of [][]byte +func AsSplitPEMBundle(certPool *CertPool) [][]byte { + return certPool.getCertsPEM() +} + +// Get the split bundle of all certificates in the certificates pool as representation of []byte +func AsPEMBundleBytes(certPool *CertPool) []byte { + return bytes.TrimSpace(bytes.Join(certPool.getCertsPEM(), nil)) +} + +// Get the split bundle of all certificates in the certificates pool as representation of []string +func AsPEMBundleStrings(certPool *CertPool) []string { + var certList = make([]string, 0) + + for _, cert := range certPool.getCertsPEM() { + certList = append(certList, strings.TrimSpace(string(cert))) + } + + return certList +} + +// Get the list of all x509 Certificates in the certificates pool +func AsCertificateList(certPool *CertPool) []*x509.Certificate { + return certPool.getCertsList() +} + +func GetCertificatesQuantity(certPool *CertPool) int { + return certPool.size() +} diff --git a/pkg/util/pem_test.go b/pkg/util/pem_test.go index 34fda82d..c8fafb26 100644 --- a/pkg/util/pem_test.go +++ b/pkg/util/pem_test.go @@ -18,6 +18,7 @@ package util import ( "bytes" + "crypto/sha256" "crypto/x509" "strings" "testing" @@ -35,10 +36,12 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { } cases := map[string]struct { - parts []string - filterExpiredCerts bool - expectExpiredCerts bool - expectErr bool + parts []string + filterDuplicateCerts bool + filterExpiredCerts bool + expectExpiredCerts bool + expectErr bool + expectDuplicatesCerts bool }{ "valid bundle with all types of cert and no comments succeeds": { parts: []string{dummy.TestCertificate1, dummy.TestCertificate2, dummy.TestCertificate3}, @@ -90,15 +93,22 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { filterExpiredCerts: true, expectErr: true, }, + "duplicate certificate should be removed": { + parts: []string{dummy.TestCertificate1, dummy.JoinCerts(dummy.TestCertificate1, dummy.TestCertificate1), dummy.TestCertificate2, dummy.TestCertificate2}, + filterExpiredCerts: true, + expectErr: false, + expectDuplicatesCerts: true, + }, } for name, test := range cases { t.Run(name, func(t *testing.T) { - validateOpts := ValidateAndSanitizeOptions{FilterExpired: test.filterExpiredCerts} + _ = name + var certPool = NewCertPool(WithFilteredExpiredCerts(test.filterExpiredCerts)) inputBundle := []byte(strings.Join(test.parts, "\n")) - sanitizedBundleBytes, err := ValidateAndSanitizePEMBundleWithOptions(inputBundle, validateOpts) + err := ValidateAndSplitPEMBundle(certPool, inputBundle) if test.expectErr != (err != nil) { t.Fatalf("ValidateAndSanitizePEMBundle: expectErr: %v | err: %v", test.expectErr, err) @@ -108,22 +118,22 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { return } - if sanitizedBundleBytes == nil { + if GetCertificatesQuantity(certPool) == 0 { t.Fatalf("got no error from ValidateAndSanitizePEMBundle but sanitizedBundle was nil") } for _, strippable := range strippableText { - if bytes.Contains(sanitizedBundleBytes, strippable) { + if bytes.Contains(AsPEMBundleBytes(certPool), strippable) { // can't print the comment since it could be an invalid string t.Errorf("expected sanitizedBundle to not contain a comment but it did") } } - if !utf8.Valid(sanitizedBundleBytes) { + if !utf8.Valid(AsPEMBundleBytes(certPool)) { t.Error("expected sanitizedBundle to be valid UTF-8 but it wasn't") } - sanitizedBundle := string(sanitizedBundleBytes) + sanitizedBundle := string(AsPEMBundleBytes(certPool)) if strings.HasSuffix(sanitizedBundle, "\n") { t.Errorf("expected sanitizedBundle not to end with a newline") @@ -141,7 +151,7 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { } } - certs, err := ValidateAndSplitPEMBundleWithOptions(sanitizedBundleBytes, validateOpts) + certs := AsCertificateList(certPool) if err != nil { t.Errorf("failed to split already-validated bundle: %s", err) return @@ -150,28 +160,24 @@ func TestValidateAndSanitizePEMBundle(t *testing.T) { var expiredCerts []*x509.Certificate for _, cert := range certs { - parsedCerts, err := DecodeX509CertificateChainBytes(cert) - if err != nil { - t.Errorf("failed to decode split PEM cert: %s", err) - continue - } - - if len(parsedCerts) != 1 { - // shouldn't ever happen since we're decoding a single PEM cert - t.Errorf("got more than one parsed cert after splitting a PEM bundle") - continue - } - - parsedCert := parsedCerts[0] - - if parsedCert.NotAfter.Before(dummy.DummyInstant()) { - expiredCerts = append(expiredCerts, parsedCert) + if cert.NotAfter.Before(dummy.DummyInstant()) { + expiredCerts = append(expiredCerts, cert) } } if test.expectExpiredCerts != (len(expiredCerts) > 0) { t.Errorf("expectExpiredCerts=%v but got %d expired certs", test.expectExpiredCerts, len(expiredCerts)) } + + if test.expectDuplicatesCerts { + var hashes = make(map[[32]byte]struct{}) + for _, cert := range certs { + hash := sha256.Sum256(cert.Raw) + if _, ok := hashes[hash]; ok { + t.Errorf("expectDuplicatesCerts=%v but got duplicate certs", test.expectDuplicatesCerts) + } + } + } }) } } diff --git a/test/env/data.go b/test/env/data.go index c83fe5a4..67e3d3df 100644 --- a/test/env/data.go +++ b/test/env/data.go @@ -217,11 +217,13 @@ func CheckBundleSyncedStartsWith(ctx context.Context, cl client.Client, name str return fmt.Errorf("received data didn't start with expected data") } + var certPool = util.NewCertPool(util.WithFilteredExpiredCerts(false)) + remaining := strings.TrimPrefix(got, startingData) // check that there are a nonzero number of valid certs remaining - _, err := util.ValidateAndSanitizePEMBundle([]byte(remaining)) + err := util.ValidateAndSplitPEMBundle(certPool, []byte(remaining)) if err != nil { return fmt.Errorf("received data didn't have any valid certs after valid starting data: %w", err) } @@ -317,6 +319,7 @@ func EventuallyBundleHasSyncedAllNamespacesStartsWith(ctx context.Context, cl cl // CheckJKSFileSynced ensures that the given JKS data func CheckJKSFileSynced(jksData []byte, expectedPassword string, expectedCertPEMData string) error { reader := bytes.NewReader(jksData) + var certPool = util.NewCertPool(util.WithFilteredExpiredCerts(false)) ks := jks.New() @@ -325,7 +328,7 @@ func CheckJKSFileSynced(jksData []byte, expectedPassword string, expectedCertPEM return err } - expectedCertList, err := util.ValidateAndSplitPEMBundle([]byte(expectedCertPEMData)) + err = util.ValidateAndSplitPEMBundle(certPool, []byte(expectedCertPEMData)) if err != nil { return fmt.Errorf("invalid PEM data passed to CheckJKSFileSynced: %s", err) } @@ -334,7 +337,7 @@ func CheckJKSFileSynced(jksData []byte, expectedPassword string, expectedCertPEM // that the count is the same aliasCount := len(ks.Aliases()) - expectedPEMCount := len(expectedCertList) + expectedPEMCount := len(util.AsSplitPEMBundle(certPool)) if aliasCount != expectedPEMCount { return fmt.Errorf("expected %d certificates in JKS but found %d", expectedPEMCount, aliasCount)