From f6b55976bb5ca18c479b7c56e69ba812e2bdffdf Mon Sep 17 00:00:00 2001 From: Jaime Soriano Pastor Date: Wed, 15 Jun 2022 10:41:18 +0200 Subject: [PATCH] Add helpers to issue certificates --- internal/certs/certs.go | 315 +++++++++++++++++++++++++++++++++++ internal/certs/certs_test.go | 157 +++++++++++++++++ 2 files changed, 472 insertions(+) create mode 100644 internal/certs/certs.go create mode 100644 internal/certs/certs_test.go diff --git a/internal/certs/certs.go b/internal/certs/certs.go new file mode 100644 index 0000000000..9b5f5a067a --- /dev/null +++ b/internal/certs/certs.go @@ -0,0 +1,315 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package certs + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "math/big" + "net" + "os" + "path/filepath" + "time" + + "github.com/elastic/elastic-package/internal/common" +) + +// Certificate contains the key and certificate for an issued certificate. +type Certificate struct { + key crypto.Signer + cert *x509.Certificate + issuer *Certificate +} + +// LoadCertificate loads a certificate and key from disk. +func LoadCertificate(certFile, keyFile string) (*Certificate, error) { + pair, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + if len(pair.Certificate) == 0 { + return nil, fmt.Errorf("no certificates in %q?", certFile) + } + chain := make([]*x509.Certificate, len(pair.Certificate)) + for i := range pair.Certificate { + cert, err := x509.ParseCertificate(pair.Certificate[i]) + if err != nil { + return nil, fmt.Errorf("failed to parse #%d certificate loaded from %q", i, certFile) + } + chain[i] = cert + } + + var key crypto.Signer + switch privKey := pair.PrivateKey.(type) { + case crypto.Signer: + key = privKey + default: + return nil, fmt.Errorf("key of type %T cannot be used", privKey) + } + + cert := &Certificate{ + key: key, + cert: chain[0], + } + + if len(chain) > 1 { + // This is an intermediate certificate, rebuild the full chain. + c := cert + for _, cert := range chain[1:] { + c.issuer = &Certificate{ + // Parent keys are not known here, but that's ok + // as these certs are only used for the cert chain. + cert: cert, + } + c = c.issuer + } + } + return cert, nil +} + +// Issuer is a certificate that can issue other certificates. +type Issuer struct { + *Certificate +} + +// NewCA creates a new self-signed root CA. +func NewCA() (*Issuer, error) { + return newCA(nil) +} + +// LoadCA loads a CA certificate and key from disk. +func LoadCA(certFile, keyFile string) (*Issuer, error) { + cert, err := LoadCertificate(certFile, keyFile) + if err != nil { + return nil, err + } + + return &Issuer{cert}, nil +} + +func newCA(parent *Issuer) (*Issuer, error) { + cert, err := New(true, parent) + if err != nil { + return nil, err + } + return &Issuer{Certificate: cert}, nil +} + +// IssueIntermediate issues an intermediate CA signed by the issuer. +func (i *Issuer) IssueIntermediate() (*Issuer, error) { + return newCA(i) +} + +// Issue issues a certificate with the given options. This certificate +// can be used to configure a TLS server. +func (i *Issuer) Issue(opts ...Option) (*Certificate, error) { + return New(false, i, opts...) +} + +// NewSelfSignedCert issues a self-signed certificate with the given options. +// This certificate can be used to configure a TLS server. +func NewSelfSignedCert(opts ...Option) (*Certificate, error) { + return New(false, nil, opts...) +} + +// Option is a function that can modify a certificate template. To be used +// when issuing certificates. +type Option func(template *x509.Certificate) + +// WithName is an option to configure the common and alternate DNS names of a certificate. +func WithName(name string) Option { + return func(template *x509.Certificate) { + template.Subject.CommonName = name + if !common.StringSliceContains(template.DNSNames, name) { + template.DNSNames = append(template.DNSNames, name) + } + } +} + +// New is the main helper to create a certificate, it is recommended to +// use the more specific ones for specific use cases. +func New(isCA bool, issuer *Issuer, opts ...Option) (*Certificate, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate key: %w", err) + } + + sn, err := newSerialNumber() + if err != nil { + return nil, fmt.Errorf("failed to get a unique serial number: %w", err) + } + + const longTime = 100 * 24 * 365 * time.Hour + template := x509.Certificate{ + NotBefore: time.Now(), + NotAfter: time.Now().Add(longTime), + + SerialNumber: sn, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + + if isCA { + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCRLSign | x509.KeyUsageCertSign + + if issuer == nil { + template.Subject.CommonName = "elastic-package CA" + } else { + template.Subject.CommonName = "intermediate elastic-package CA" + } + } else { + // Include local hostname and ips as alternates in service certificates. + template.DNSNames = []string{"localhost"} + template.IPAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")} + } + + for _, opt := range opts { + opt(&template) + } + + // Self-signed unless an issuer has been received. + var parent *x509.Certificate = &template + var signer crypto.Signer = key + var issuerCert *Certificate + if issuer != nil { + parent = issuer.cert + signer = issuer.key + issuerCert = issuer.Certificate + template.Issuer = issuer.cert.Subject + } + + der, err := x509.CreateCertificate(rand.Reader, &template, parent, key.Public(), signer) + if err != nil { + return nil, fmt.Errorf("failed to generate certificate: %w", err) + } + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + return &Certificate{ + key: key, + cert: cert, + issuer: issuerCert, + }, nil +} + +func newSerialNumber() (*big.Int, error) { + // This implementation attempts to get unique serial numbers + // by getting random ones between 0 and 2^128. + max := new(big.Int).Exp(big.NewInt(2), big.NewInt(128), nil) + return rand.Int(rand.Reader, max) +} + +// WriteKey writes the PEM-encoded key in the given writer. +func (c *Certificate) WriteKey(w io.Writer) error { + keyPem, err := keyPemBlock(c.key) + if err != nil { + return fmt.Errorf("failed to encode key PEM block: %w", err) + } + + return encodePem(w, keyPem) +} + +// WriteKeyFile writes the PEM-encoded key in the given file. +func (c *Certificate) WriteKeyFile(path string) error { + err := os.MkdirAll(filepath.Dir(path), 0755) + if err != nil { + return fmt.Errorf("error creating directory for key file: %w", err) + } + f, err := os.Create(path) + if err != nil { + return fmt.Errorf("failed to create key file %q: %w", path, err) + } + defer f.Close() + + return c.WriteKey(f) +} + +// WriteCert writes the PEM-encoded certificate chain in the given writer. +func (c *Certificate) WriteCert(w io.Writer) error { + for i := c; i != nil; i = i.issuer { + err := encodePem(w, certPemBlock(i.cert.Raw)) + if err != nil { + return err + } + } + + return nil +} + +// WriteCertFile writes the PEM-encoded certificate in the given file. +func (c *Certificate) WriteCertFile(path string) error { + err := os.MkdirAll(filepath.Dir(path), 0755) + if err != nil { + return fmt.Errorf("error creating directory for certificate file: %w", err) + } + f, err := os.Create(path) + if err != nil { + return fmt.Errorf("failed to create cert file %q: %w", path, err) + } + defer f.Close() + + return c.WriteCert(f) +} + +// Verify verifies a certificate with the given verification options. +func (c *Certificate) Verify(options x509.VerifyOptions) error { + _, err := c.cert.Verify(options) + return err +} + +func certPemBlock(cert []byte) *pem.Block { + const certificatePemType = "CERTIFICATE" + return &pem.Block{ + Type: certificatePemType, + Bytes: cert, + } +} + +func keyPemBlock(key crypto.Signer) (*pem.Block, error) { + const ( + ecPrivateKeyPemType = "EC PRIVATE KEY" + rsaPrivateKeyPemType = "RSA PRIVATE KEY" + ) + switch key := key.(type) { + case *rsa.PrivateKey: + d := x509.MarshalPKCS1PrivateKey(key) + return &pem.Block{ + Type: rsaPrivateKeyPemType, + Bytes: d, + }, nil + case *ecdsa.PrivateKey: + d, err := x509.MarshalECPrivateKey(key) + if err != nil { + return nil, fmt.Errorf("failed to encode EC private key: %w", err) + } + return &pem.Block{ + Type: ecPrivateKeyPemType, + Bytes: d, + }, nil + default: + return nil, fmt.Errorf("unsupported key type %T", key) + } +} + +func encodePem(w io.Writer, blocks ...*pem.Block) error { + for _, block := range blocks { + err := pem.Encode(w, block) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/certs/certs_test.go b/internal/certs/certs_test.go new file mode 100644 index 0000000000..6bceea059f --- /dev/null +++ b/internal/certs/certs_test.go @@ -0,0 +1,157 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package certs + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "net/http" + "os/exec" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSelfSignedCertificate(t *testing.T) { + const commonName = "someserver" + cert, err := NewSelfSignedCert(WithName(commonName)) + require.NoError(t, err) + + address := testTLSServer(t, cert) + testTLSClient(t, cert, commonName, address) +} + +func TestCA(t *testing.T) { + ca, err := NewCA() + require.NoError(t, err) + + intermediate, err := ca.IssueIntermediate() + require.NoError(t, err) + + const commonName = "elasticsearch" + cert, err := intermediate.Issue(WithName(commonName)) + require.NoError(t, err) + + t.Run("validate server with root CA", func(t *testing.T) { + address := testTLSServer(t, cert) + t.Run("go-http client", func(t *testing.T) { + testTLSClient(t, ca.Certificate, commonName, address) + }) + t.Run("curl", func(t *testing.T) { + testCurl(t, ca.Certificate, commonName, address) + }) + }) + + t.Run("validate server with intermediate CA", func(t *testing.T) { + address := testTLSServer(t, cert) + t.Run("go-http client", func(t *testing.T) { + testTLSClient(t, intermediate.Certificate, commonName, address) + }) + t.Run("curl", func(t *testing.T) { + testCurl(t, intermediate.Certificate, commonName, address) + }) + }) +} + +func testTLSServer(t *testing.T, cert *Certificate) string { + tmpDir := t.TempDir() + keyFile := filepath.Join(tmpDir, "cert.key") + certFile := filepath.Join(tmpDir, "cert.pem") + + err := cert.WriteKeyFile(keyFile) + require.NoError(t, err) + + err = cert.WriteCertFile(certFile) + require.NoError(t, err) + + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + t.Cleanup(func() { listener.Close() }) + + go func() { + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "ok") + }), + } + server.ServeTLS(listener, certFile, keyFile) + }() + + return listener.Addr().String() +} + +func testTLSClient(t *testing.T, root *Certificate, commonName, address string) { + caPool := x509.NewCertPool() + caPool.AddCert(root.cert) + client := &http.Client{ + Transport: &http.Transport{ + // Send all requests to the listener address. + DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, address) + }, + DialTLSContext: func(ctx context.Context, network, reqAddress string) (net.Conn, error) { + var d tls.Dialer + host, _, _ := net.SplitHostPort(reqAddress) + d.Config = &tls.Config{ + ServerName: host, + RootCAs: caPool, + } + return d.DialContext(ctx, network, address) + }, + }, + } + + resp, err := client.Get("https://" + commonName) + require.NoError(t, err) + defer resp.Body.Close() + d, _ := ioutil.ReadAll(resp.Body) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "ok", string(d)) + +} + +func testCurl(t *testing.T, root *Certificate, commonName, address string) { + _, err := exec.LookPath("curl") + if err != nil { + t.Skip("curl not available") + } + + caCert := filepath.Join(t.TempDir(), "ca-cert.pem") + err = root.WriteCertFile(caCert) + require.NoError(t, err) + + serverHost, port, err := net.SplitHostPort(address) + require.NoError(t, err) + require.NotNilf(t, net.ParseIP(serverHost), "%s expected to be an ip", serverHost) + + // Address to use in the request, hostname here must match name in certificate. + reqAddress := net.JoinHostPort(commonName, port) + + args := []string{ + "-v", + "--cacert", caCert, + // Send requests to the listener address. + "--resolve", reqAddress + ":" + serverHost, + "https://" + reqAddress, + } + + var buf bytes.Buffer + cmd := exec.Command("curl", args...) + cmd.Stderr = &buf + cmd.Stdout = &buf + + err = cmd.Run() + if !assert.NoError(t, err) { + t.Log(buf.String()) + } +}