From 4e21f87e3d014d606bb3ba2a89731a7d24806611 Mon Sep 17 00:00:00 2001
From: Tony Grosinger
Date: Thu, 20 Apr 2017 08:04:18 -0700
Subject: [PATCH 1/2] pkg/transport: reload TLS certificates for every client requests

This changes the baseConfig used when creating tls Configs to utilize
the GetCertificate and GetClientCertificate functions to always reload
the certificates from disk whenever they are needed.

Always reloading the certificates allows changing the certificates via
an external process without interrupting etcd.

Fixes #7576

Cherry-picked by Gyu-Ho Lee
Original commit can be found at https://github.com/coreos/etcd/pull/7784
---
 pkg/transport/listener.go | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go
index e024f3c6bf4..76b36d94428 100644
--- a/pkg/transport/listener.go
+++ b/pkg/transport/listener.go
@@ -172,6 +172,14 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
 		MinVersion: tls.VersionTLS12,
 		ServerName: info.ServerName,
 	}
+	// this only reloads certs when there's a client request
+	// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
+	cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+		return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
+	}
+	cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (*tls.Certificate, error) {
+		return tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
+	}
 	return cfg, nil
 }
 

From 22943e7e06910c0e5269b6260e04cf6f5ff67739 Mon Sep 17 00:00:00 2001
From: Gyu-Ho Lee
Date: Thu, 27 Apr 2017 06:01:10 -0700
Subject: [PATCH 2/2] integration: test TLS reload

Signed-off-by: Gyu-Ho Lee
---
 integration/cluster.go      |   7 ++
 integration/util_test.go    |  65 +++++++++++
 integration/v3_grpc_test.go | 209 ++++++++++++++++++++++++++++++++++++
 3 files changed, 281 insertions(+)
 create mode 100644 integration/util_test.go

diff --git a/integration/cluster.go b/integration/cluster.go
index fd7330a8533..7af9d77a1a3 100644
--- a/integration/cluster.go
+++ b/integration/cluster.go
@@ -76,6 +76,13 @@ var (
 		ClientCertAuth: true,
 	}
 
+	testTLSInfoExpired = transport.TLSInfo{
+		KeyFile:        "./fixtures-expired/server-key.pem",
+		CertFile:       "./fixtures-expired/server.pem",
+		TrustedCAFile:  "./fixtures-expired/etcd-root-ca.pem",
+		ClientCertAuth: true,
+	}
+
 	plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "integration")
 )
 
diff --git a/integration/util_test.go b/integration/util_test.go
new file mode 100644
index 00000000000..18894198016
--- /dev/null
+++ b/integration/util_test.go
@@ -0,0 +1,65 @@
+// Copyright 2017 The etcd Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration
+
+import (
+	"io"
+	"os"
+	"path/filepath"
+
+	"github.com/coreos/etcd/pkg/transport"
+)
+
+// copyTLSFiles copies the key, cert, and trusted CA files referenced by ti
+// into dst and returns a TLSInfo that points at the copies.
+func copyTLSFiles(ti transport.TLSInfo, dst string) (transport.TLSInfo, error) {
+	ci := transport.TLSInfo{
+		KeyFile:        filepath.Join(dst, "server-key.pem"),
+		CertFile:       filepath.Join(dst, "server.pem"),
+		TrustedCAFile:  filepath.Join(dst, "etcd-root-ca.pem"),
+		ClientCertAuth: ti.ClientCertAuth,
+	}
+	if err := copyFile(ti.KeyFile, ci.KeyFile); err != nil {
+		return transport.TLSInfo{}, err
+	}
+	if err := copyFile(ti.CertFile, ci.CertFile); err != nil {
+		return transport.TLSInfo{}, err
+	}
+	if err := copyFile(ti.TrustedCAFile, ci.TrustedCAFile); err != nil {
+		return transport.TLSInfo{}, err
+	}
+	return ci, nil
+}
+
+// copyFile copies the contents of src to dst, creating or truncating dst,
+// and syncs dst to disk before returning.
+func copyFile(src, dst string) error {
+	f, err := os.Open(src)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	w, err := os.Create(dst)
+	if err != nil {
+		return err
+	}
+	defer w.Close()
+
+	if _, err = io.Copy(w, f); err != nil {
+		return err
+	}
+	return w.Sync()
+}
diff --git a/integration/v3_grpc_test.go b/integration/v3_grpc_test.go
index 5113821def1..6ebc82d23c4 100644
--- a/integration/v3_grpc_test.go
+++ b/integration/v3_grpc_test.go
@@ -16,17 +16,21 @@ package integration
 
 import (
 	"bytes"
+	"crypto/tls"
 	"fmt"
+	"io/ioutil"
 	"math/rand"
 	"os"
 	"reflect"
 	"testing"
 	"time"
 
+	"github.com/coreos/etcd/clientv3"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc"
 	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
 	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
 	"github.com/coreos/etcd/pkg/testutil"
+
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/metadata"
@@ -1374,6 +1378,211 @@ func TestTLSGRPCAcceptSecureAll(t *testing.T) {
 	}
 }
 
+// TestTLSReloadAtomicReplace ensures server reloads expired/valid certs
+// when all certs are atomically replaced by directory renaming.
+// And expects server to reject client requests, and vice versa.
+func TestTLSReloadAtomicReplace(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	// clone valid,expired certs to separate directories for atomic renaming
+	vDir, err := ioutil.TempDir(os.TempDir(), "fixtures-valid")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(vDir)
+	ts, err := copyTLSFiles(testTLSInfo, vDir)
+	if err != nil {
+		t.Fatal(err)
+	}
+	eDir, err := ioutil.TempDir(os.TempDir(), "fixtures-expired")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(eDir)
+	if _, err = copyTLSFiles(testTLSInfoExpired, eDir); err != nil {
+		t.Fatal(err)
+	}
+
+	tDir, err := ioutil.TempDir(os.TempDir(), "fixtures")
+	if err != nil {
+		t.Fatal(err)
+	}
+	os.RemoveAll(tDir)
+	defer os.RemoveAll(tDir)
+
+	// start with valid certs
+	clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts})
+	defer clus.Terminate(t)
+
+	// concurrent client dialing while certs transition from valid to expired
+	errc := make(chan error, 1)
+	go func() {
+		for {
+			cc, err := ts.ClientConfig()
+			if err != nil {
+				if os.IsNotExist(err) {
+					// from concurrent renaming
+					continue
+				}
+				// Fatal must run on the test goroutine; report to it instead
+				errc <- err
+				return
+			}
+			cli, cerr := clientv3.New(clientv3.Config{
+				Endpoints:   []string{clus.Members[0].GRPCAddr()},
+				DialTimeout: time.Second,
+				TLS:         cc,
+			})
+			if cerr != nil {
+				errc <- cerr
+				return
+			}
+			cli.Close()
+		}
+	}()
+
+	// replace certs directory with expired ones
+	if err = os.Rename(vDir, tDir); err != nil {
+		t.Fatal(err)
+	}
+	if err = os.Rename(eDir, vDir); err != nil {
+		t.Fatal(err)
+	}
+
+	// after rename,
+	// 'vDir' contains expired certs
+	// 'tDir' contains valid certs
+	// 'eDir' does not exist
+
+	select {
+	case err = <-errc:
+		if err != grpc.ErrClientConnTimeout {
+			t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, err)
+		}
+	case <-time.After(5 * time.Second):
+		t.Fatal("failed to receive dial timeout error")
+	}
+
+	// now, replace expired certs back with valid ones
+	if err = os.Rename(tDir, eDir); err != nil {
+		t.Fatal(err)
+	}
+	if err = os.Rename(vDir, tDir); err != nil {
+		t.Fatal(err)
+	}
+	if err = os.Rename(eDir, vDir); err != nil {
+		t.Fatal(err)
+	}
+
+	// new incoming client request should trigger
+	// listener to reload valid certs
+	var tlsCfg *tls.Config
+	tlsCfg, err = ts.ClientConfig()
+	if err != nil {
+		t.Fatal(err)
+	}
+	var cl *clientv3.Client
+	cl, err = clientv3.New(clientv3.Config{
+		Endpoints:   []string{clus.Members[0].GRPCAddr()},
+		DialTimeout: time.Second,
+		TLS:         tlsCfg,
+	})
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	cl.Close()
+}
+
+// TestTLSReloadCopy ensures server reloads expired/valid certs
+// when new certs are copied over, one by one. And expects server
+// to reject client requests, and vice versa.
+func TestTLSReloadCopy(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	// clone certs directory, free to overwrite
+	cDir, err := ioutil.TempDir(os.TempDir(), "fixtures-test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(cDir)
+	ts, err := copyTLSFiles(testTLSInfo, cDir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// start with valid certs
+	clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts})
+	defer clus.Terminate(t)
+
+	// concurrent client dialing while certs transition from valid to expired
+	errc := make(chan error, 1)
+	go func() {
+		for {
+			cc, err := ts.ClientConfig()
+			if err != nil {
+				// from concurrent certs overwriting
+				switch err.Error() {
+				case "tls: private key does not match public key":
+					fallthrough
+				case "tls: failed to find any PEM data in key input":
+					continue
+				}
+				// Fatal must run on the test goroutine; report to it instead
+				errc <- err
+				return
+			}
+			cli, cerr := clientv3.New(clientv3.Config{
+				Endpoints:   []string{clus.Members[0].GRPCAddr()},
+				DialTimeout: time.Second,
+				TLS:         cc,
+			})
+			if cerr != nil {
+				errc <- cerr
+				return
+			}
+			cli.Close()
+		}
+	}()
+
+	// overwrite valid certs with expired ones
+	// (e.g. simulate cert expiration in practice)
+	if _, err = copyTLSFiles(testTLSInfoExpired, cDir); err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case gerr := <-errc:
+		if gerr != grpc.ErrClientConnTimeout {
+			t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, gerr)
+		}
+	case <-time.After(5 * time.Second):
+		t.Fatal("failed to receive dial timeout error")
+	}
+
+	// now, replace expired certs back with valid ones
+	if _, err = copyTLSFiles(testTLSInfo, cDir); err != nil {
+		t.Fatal(err)
+	}
+
+	// new incoming client request should trigger
+	// listener to reload valid certs
+	var tlsCfg *tls.Config
+	tlsCfg, err = ts.ClientConfig()
+	if err != nil {
+		t.Fatal(err)
+	}
+	var cl *clientv3.Client
+	cl, err = clientv3.New(clientv3.Config{
+		Endpoints:   []string{clus.Members[0].GRPCAddr()},
+		DialTimeout: time.Second,
+		TLS:         tlsCfg,
+	})
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	cl.Close()
+}
+
 func TestGRPCRequireLeader(t *testing.T) {
 	defer testutil.AfterTest(t)
 