Skip to content

Commit

Permalink
chore: encapsulate tls generation
Browse files Browse the repository at this point in the history
Signed-off-by: Armando Ruocco <armando.ruocco@enterprisedb.com>
  • Loading branch information
armru committed May 21, 2024
1 parent 3d4dd90 commit 931b2f5
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 27 deletions.
6 changes: 6 additions & 0 deletions api/v1/cluster_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"

"github.com/cloudnative-pg/cloudnative-pg/internal/configuration"
"github.com/cloudnative-pg/cloudnative-pg/pkg/management/log"
Expand Down Expand Up @@ -3438,6 +3439,11 @@ func (cluster *Cluster) GetTablespaceConfiguration(name string) *TablespaceConfi
return nil
}

// GetServerCASecretObjectKey returns a types.NamespacedName pointing to the secret
func (cluster *Cluster) GetServerCASecretObjectKey() types.NamespacedName {
return types.NamespacedName{Namespace: cluster.Namespace, Name: cluster.GetServerCASecretName()}
}

// IsBarmanBackupConfigured returns true if one of the possible backup destination
// is configured, false otherwise
func (backupConfiguration *BackupConfiguration) IsBarmanBackupConfigured() bool {
Expand Down
8 changes: 7 additions & 1 deletion controllers/backup_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/manager"

apiv1 "github.com/cloudnative-pg/cloudnative-pg/api/v1"
"github.com/cloudnative-pg/cloudnative-pg/pkg/certs"
"github.com/cloudnative-pg/cloudnative-pg/pkg/conditions"
"github.com/cloudnative-pg/cloudnative-pg/pkg/management/log"
"github.com/cloudnative-pg/cloudnative-pg/pkg/management/postgres"
Expand Down Expand Up @@ -143,7 +144,12 @@ func (r *BackupReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr
contextLogger.Debug("Found cluster for backup", "cluster", clusterName)

// Store in the context the TLS configuration required communicating with the Pods
tlsConfig, err := newTLSConfigFromCluster(ctx, &cluster, r.Client)
tlsConfig, err := certs.NewTLSFromSecret(
ctx,
r.Client,
cluster.GetServerCASecretObjectKey(),
cluster.GetServiceReadWriteName(),
)
if err != nil {
return ctrl.Result{}, err
}
Expand Down
32 changes: 6 additions & 26 deletions controllers/cluster_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package controllers

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"reflect"
Expand Down Expand Up @@ -252,7 +250,12 @@ func (r *ClusterReconciler) reconcile(ctx context.Context, cluster *apiv1.Cluste
}

// Store in the context the TLS configuration required communicating with the Pods
tlsConfig, err := newTLSConfigFromCluster(ctx, cluster, r.Client)
tlsConfig, err := certs.NewTLSFromSecret(
ctx,
r.Client,
cluster.GetServerCASecretObjectKey(),
cluster.GetServiceReadWriteName(),
)
if err != nil {
return ctrl.Result{}, err
}
Expand Down Expand Up @@ -404,29 +407,6 @@ func (r *ClusterReconciler) reconcile(ctx context.Context, cluster *apiv1.Cluste
return hookResult.Result, hookResult.Err
}

func newTLSConfigFromCluster(ctx context.Context, cluster *apiv1.Cluster, c client.Client) (*tls.Config, error) {
secret := &corev1.Secret{}
err := c.Get(ctx, client.ObjectKey{Namespace: cluster.Namespace, Name: cluster.GetServerCASecretName()}, secret)
if err != nil {
return nil, fmt.Errorf("while getting secret %s: %w", cluster.GetServerCASecretName(), err)
}

caCertificate, ok := secret.Data[certs.CACertKey]
if !ok {
return nil, fmt.Errorf("missing %s entry in secret %s", certs.CACertKey, cluster.GetServerCASecretName())
}

caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCertificate)
tlsConfig := tls.Config{
MinVersion: tls.VersionTLS13,
ServerName: cluster.GetServiceReadWriteName(),
RootCAs: caCertPool,
}

return &tlsConfig, nil
}

func (r *ClusterReconciler) handleSwitchover(
ctx context.Context,
cluster *apiv1.Cluster,
Expand Down
57 changes: 57 additions & 0 deletions pkg/certs/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
Copyright The CloudNativePG Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package certs

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"

v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
)

// NewTLSFromSecret creates a tls.Config from the given ca secret and serverName pair
func NewTLSFromSecret(
ctx context.Context,
c client.Client,
caSecret types.NamespacedName,
serverName string,
) (*tls.Config, error) {
secret := &v1.Secret{}
err := c.Get(ctx, caSecret, secret)
if err != nil {
return nil, fmt.Errorf("while getting secret %s: %w", caSecret.Name, err)
}

caCertificate, ok := secret.Data[CACertKey]
if !ok {
return nil, fmt.Errorf("missing %s entry in secret %s", CACertKey, caSecret.Name)
}

caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCertificate)
tlsConfig := tls.Config{
MinVersion: tls.VersionTLS13,
ServerName: serverName,
RootCAs: caCertPool,
}

return &tlsConfig, nil
}
107 changes: 107 additions & 0 deletions pkg/certs/tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
Copyright The CloudNativePG Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package certs

import (
"context"
"crypto/tls"
"fmt"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/client/fake"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("NewTLSFromSecret", func() {
var (
ctx context.Context
c client.Client
caSecret types.NamespacedName
serverName string
)

BeforeEach(func() {
ctx = context.TODO()
caSecret = types.NamespacedName{Name: "test-secret", Namespace: "default"}
serverName = "test-server"
})

Context("when the secret is found and valid", func() {
BeforeEach(func() {
secretData := map[string][]byte{
CACertKey: []byte(`-----BEGIN CERTIFICATE-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA7Qe3X7Q6WZpXqlXkq0Bd
... (rest of the CA certificate) ...
-----END CERTIFICATE-----`),
}
secret := &v1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: caSecret.Name,
Namespace: caSecret.Namespace,
},
Data: secretData,
}
c = fake.NewClientBuilder().WithObjects(secret).Build()
})

It("should return a valid tls.Config", func() {
tlsConfig, err := NewTLSFromSecret(ctx, c, caSecret, serverName)
Expect(err).NotTo(HaveOccurred())
Expect(tlsConfig).NotTo(BeNil())
Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS13)))
Expect(tlsConfig.ServerName).To(Equal(serverName))
Expect(tlsConfig.RootCAs).ToNot(BeNil())
})
})

Context("when the secret is not found", func() {
BeforeEach(func() {
c = fake.NewClientBuilder().Build()
})

It("should return an error", func() {
tlsConfig, err := NewTLSFromSecret(ctx, c, caSecret, serverName)
Expect(err).To(HaveOccurred())
Expect(tlsConfig).To(BeNil())
Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("while getting secret %s", caSecret.Name)))
})
})

Context("when the ca.crt entry is missing in the secret", func() {
BeforeEach(func() {
secret := &v1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: caSecret.Name,
Namespace: caSecret.Namespace,
},
}
c = fake.NewClientBuilder().WithObjects(secret).Build()
})

It("should return an error", func() {
tlsConfig, err := NewTLSFromSecret(ctx, c, caSecret, serverName)
Expect(err).To(HaveOccurred())
Expect(tlsConfig).To(BeNil())
Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("missing %s entry in secret %s", CACertKey, caSecret.Name)))
})
})
})

0 comments on commit 931b2f5

Please sign in to comment.