Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ca.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
"certlifetime": 30,
"metrics": {
"enabled": true,
"port": 9123
"port": 9123,
"dataSource": "/opt/acme-proxy/db/metrics"
}
},
"provisioners": [
Expand Down
113 changes: 105 additions & 8 deletions externalcas/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"fmt"
"log/slog"
"net/http"
"strings"
"time"

"github.com/go-acme/lego/v4/certcrypto"
Expand All @@ -29,7 +30,15 @@ func init() {
}

func New(ctx context.Context, opts apiv1.Options) (*ExternalCAS, error) {
return &ExternalCAS{ctx: ctx, config: opts.Config}, nil
cas := &ExternalCAS{ctx: ctx, config: opts.Config}
cfg, err := cas.parseConfig()
if err != nil {
return nil, err
}
if err := StartMetricsServer(cfg.Metrics); err != nil {
return nil, err
}
return cas, nil
}

// AcmeProxyConfig contains the configuration for connecting to an external ACME CA
Expand Down Expand Up @@ -63,6 +72,9 @@ func (c *AcmeProxyConfig) Validate() error {
if c.CertLifetime < 0 {
return errors.New("certlifetime cannot be negative")
}
if c.Metrics.Enabled && c.Metrics.DataSource == "" {
return errors.New("metrics.datasource is required when metrics is enabled")
}
return nil
}

Expand All @@ -77,8 +89,9 @@ func (c *AcmeProxyConfig) RequestTimeout() time.Duration {
}

type Metrics struct {
Enabled bool `json:"enabled,omitempty"`
Port int `json:"port,omitempty"`
Enabled bool `json:"enabled,omitempty"`
Port int `json:"port,omitempty"`
DataSource string `json:"datasource,omitempty"`
}

// User implements the lego registration.User interface
Expand Down Expand Up @@ -159,6 +172,7 @@ func splitCertificateBundle(pemBytes []byte) (*x509.Certificate, []*x509.Certifi
// certificateResult holds the result of an async certificate operation
type certificateResult struct {
response *apiv1.CreateCertificateResponse
duration time.Duration // time ObtainForCSR took; used for metrics
err error
}

Expand Down Expand Up @@ -273,18 +287,23 @@ func (c *ExternalCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*a
}
}()

start := time.Now()
cert, err := acmeClient.ObtainForCSR(csrRequest)
duration := time.Since(start)

if err != nil {
resultChan <- &certificateResult{
err: fmt.Errorf("failed to obtain certificate: %w", err),
err: fmt.Errorf("failed to obtain certificate: %w", err),
duration: duration,
}
return
}

leaf, intermediates, err := splitCertificateBundle(cert.Certificate)
if err != nil {
resultChan <- &certificateResult{
err: fmt.Errorf("failed to split certificate bundle: %w", err),
err: fmt.Errorf("failed to split certificate bundle: %w", err),
duration: duration,
}
return
}
Expand All @@ -294,17 +313,59 @@ func (c *ExternalCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*a
Certificate: leaf,
CertificateChain: intermediates,
},
duration: duration,
}
}()

select {
case result := <-resultChan:
if result.err != nil {
if metricsEnabled {
certificatesIssuedTotal.WithLabelValues("failure").Inc()
if req.CSR != nil {
if err := globalStore.recordIssued(CertRecord{
CommonName: req.CSR.Subject.CommonName,
SANs: strings.Join(req.CSR.DNSNames, ","),
DurationSeconds: result.duration.Seconds(),
Status: "failure",
}); err != nil {
slog.Error("failed to record cert issuance failure", "error", err)
}
}
}
return nil, result.err
}
slog.Info("obtained certificate from external CA", "domains", req.CSR.DNSNames)
if metricsEnabled {
certificatesIssuedTotal.WithLabelValues("success").Inc()
cert := result.response.Certificate
if err := globalStore.recordIssued(CertRecord{
Serial: cert.SerialNumber.Text(16),
CommonName: cert.Subject.CommonName,
Issuer: cert.Issuer.CommonName,
SANs: strings.Join(cert.DNSNames, ","),
IssuedAt: cert.NotBefore,
ExpiresAt: cert.NotAfter,
DurationSeconds: result.duration.Seconds(),
Status: "success",
}); err != nil {
slog.Error("failed to record cert issuance", "error", err)
}
}
return result.response, nil
case <-ctx.Done():
if metricsEnabled {
certificatesIssuedTotal.WithLabelValues("failure").Inc()
if req.CSR != nil {
if err := globalStore.recordIssued(CertRecord{
CommonName: req.CSR.Subject.CommonName,
SANs: strings.Join(req.CSR.DNSNames, ","),
Status: "failure",
}); err != nil {
slog.Error("failed to record cert issuance timeout", "error", err)
}
}
}
return nil, fmt.Errorf("certificate request timed out: %w", ctx.Err())
}
}
Expand Down Expand Up @@ -344,17 +405,53 @@ func (c *ExternalCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*a
"subject", req.Certificate.Subject.CommonName,
)

if err := acmeClient.Revoke(pemBytes); err != nil {
revokeStart := time.Now()
revokeErr := acmeClient.Revoke(pemBytes)
revokeDuration := time.Since(revokeStart)

if revokeErr != nil {
slog.Error("failed to revoke certificate",
"serial", req.Certificate.SerialNumber.String(),
"error", err,
"error", revokeErr,
)
return nil, fmt.Errorf("failed to revoke certificate: %w", err)
if metricsEnabled {
certificatesRevokedTotal.WithLabelValues("failure").Inc()
cert := req.Certificate
if err := globalStore.recordRevoked(CertRecord{
Serial: cert.SerialNumber.Text(16),
CommonName: cert.Subject.CommonName,
Issuer: cert.Issuer.CommonName,
SANs: strings.Join(cert.DNSNames, ","),
IssuedAt: cert.NotBefore,
ExpiresAt: cert.NotAfter,
DurationSeconds: revokeDuration.Seconds(),
Status: "failure",
}); err != nil {
slog.Error("failed to record cert revocation failure", "error", err)
}
}
return nil, fmt.Errorf("failed to revoke certificate: %w", revokeErr)
}

slog.Info("certificate revoked successfully",
"serial", req.Certificate.SerialNumber.String(),
)
if metricsEnabled {
certificatesRevokedTotal.WithLabelValues("success").Inc()
cert := req.Certificate
if err := globalStore.recordRevoked(CertRecord{
Serial: cert.SerialNumber.Text(16),
CommonName: cert.Subject.CommonName,
Issuer: cert.Issuer.CommonName,
SANs: strings.Join(cert.DNSNames, ","),
IssuedAt: cert.NotBefore,
ExpiresAt: cert.NotAfter,
DurationSeconds: revokeDuration.Seconds(),
Status: "success",
}); err != nil {
slog.Error("failed to record cert revocation", "error", err)
}
}

return &apiv1.RevokeCertificateResponse{
Certificate: req.Certificate,
Expand Down
106 changes: 71 additions & 35 deletions externalcas/external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ import (
func TestNew(t *testing.T) {
ctx := context.Background()
opts := apiv1.Options{
Type: "externalcas",
IsCreator: false,
IsCAGetter: false,
Config: []byte(""),
Type: "externalcas",
Config: mustMarshalConfig(t, &AcmeProxyConfig{
CaURL: "https://acme.example.com",
Kid: "test-kid",
HmacKey: "test-hmac",
}),
}

cas, err := New(ctx, opts)
Expand All @@ -42,6 +44,46 @@ func TestNew(t *testing.T) {
}
}

func TestNew_ValidatesConfig(t *testing.T) {
tests := []struct {
name string
config []byte
errMsg string
}{
{
name: "empty config",
config: []byte(""),
errMsg: "failed to unmarshal config",
},
{
name: "missing ca_url",
config: mustMarshalConfig(t, &AcmeProxyConfig{Kid: "k", HmacKey: "h"}),
errMsg: "ca_url is required",
},
{
name: "metrics enabled without datasource",
config: mustMarshalConfig(t, &AcmeProxyConfig{
CaURL: "https://acme.example.com",
Kid: "test-kid",
HmacKey: "test-hmac",
Metrics: Metrics{Enabled: true, DataSource: ""},
}),
errMsg: "metrics.datasource is required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := New(context.Background(), apiv1.Options{Config: tt.config})
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("error = %q, want error containing %q", err.Error(), tt.errMsg)
}
})
}
}

func TestAcmeProxyConfig_Validate(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -113,6 +155,27 @@ func TestAcmeProxyConfig_Validate(t *testing.T) {
},
wantErr: false,
},
{
name: "metrics enabled without datasource",
config: AcmeProxyConfig{
CaURL: "https://acme.example.com",
Kid: "test-kid",
HmacKey: "test-hmac",
Metrics: Metrics{Enabled: true, DataSource: ""},
},
wantErr: true,
errMsg: "metrics.datasource is required",
},
{
name: "metrics enabled with datasource",
config: AcmeProxyConfig{
CaURL: "https://acme.example.com",
Kid: "test-kid",
HmacKey: "test-hmac",
Metrics: Metrics{Enabled: true, DataSource: "/tmp/test.db"},
},
wantErr: false,
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -304,16 +367,7 @@ func Test_validateRevokeCertificateRequest(t *testing.T) {
}

func TestCreateCertificate_Validation(t *testing.T) {
ctx := context.Background()
opts := apiv1.Options{
Type: "externalcas",
Config: []byte("{}"),
}

extcas, err := New(ctx, opts)
if err != nil {
t.Fatal(err)
}
extcas := &ExternalCAS{ctx: context.Background()}

tests := []struct {
name string
Expand Down Expand Up @@ -562,16 +616,7 @@ func Test_splitCertificateBundle(t *testing.T) {
}

func TestRevokeCertificate_Validation(t *testing.T) {
ctx := context.Background()
opts := apiv1.Options{
Type: "externalcas",
Config: []byte("{}"),
}

extcas, err := New(ctx, opts)
if err != nil {
t.Fatal(err)
}
extcas := &ExternalCAS{ctx: context.Background()}

tests := []struct {
name string
Expand Down Expand Up @@ -606,18 +651,9 @@ func TestRevokeCertificate_Validation(t *testing.T) {
}

func TestRenewCertificate_NotImplemented(t *testing.T) {
ctx := context.Background()
opts := apiv1.Options{
Type: "externalcas",
Config: []byte("{}"),
}

cas, err := New(ctx, opts)
if err != nil {
t.Fatal(err)
}
cas := &ExternalCAS{ctx: context.Background()}

_, err = cas.RenewCertificate(&apiv1.RenewCertificateRequest{})
_, err := cas.RenewCertificate(&apiv1.RenewCertificateRequest{})
if err == nil {
t.Fatal("expected NotImplementedError, got nil")
}
Expand Down
Loading
Loading